diff --git a/foundry/interfaces/ICallback.sol b/foundry/interfaces/ICallback.sol index e71527e..408a5fe 100644 --- a/foundry/interfaces/ICallback.sol +++ b/foundry/interfaces/ICallback.sol @@ -2,7 +2,24 @@ pragma solidity ^0.8.26; interface ICallback { + /** + * @notice Handles callback data from a protocol or contract interaction. + * @dev This method processes callback data and returns a result. Implementations + * should handle the specific callback logic required by the protocol. + * + * @param data The encoded callback data to be processed. + * @return result The encoded result of the callback processing. + */ function handleCallback( bytes calldata data ) external returns (bytes memory result); + + /** + * @notice Verifies the validity of callback data. + * @dev This view function checks if the provided callback data is valid according + * to the protocol's requirements. It should revert if the data is invalid. + * + * @param data The encoded callback data to verify. + */ + function verifyCallback(bytes calldata data) external view; } diff --git a/foundry/src/Dispatcher.sol b/foundry/src/Dispatcher.sol index b0a69c1..d23321c 100644 --- a/foundry/src/Dispatcher.sol +++ b/foundry/src/Dispatcher.sol @@ -6,6 +6,7 @@ import "@interfaces/ICallback.sol"; error Dispatcher__UnapprovedExecutor(); error Dispatcher__NonContractExecutor(); +error Dispatcher__InvalidDataLength(); /** * @title Dispatcher - Dispatch execution to external contracts @@ -81,16 +82,11 @@ contract Dispatcher { calculatedAmount = abi.decode(result, (uint256)); } - function _handleCallback(bytes4 selector, bytes memory data) internal { - // Using assembly to access the last 20 bytes of the bytes memory data - address executor; - // slither-disable-next-line assembly - assembly { - let pos := sub(add(add(data, 0x20), mload(data)), 20) - executor := mload(pos) - executor := shr(96, executor) - } - + function _handleCallback( + bytes4 selector, + address executor, + bytes memory data + ) internal { if (!executors[executor]) { revert Dispatcher__UnapprovedExecutor(); } diff --git a/foundry/src/TychoRouter.sol b/foundry/src/TychoRouter.sol index bfdcdd3..2ec2f21 100644 --- a/foundry/src/TychoRouter.sol +++ b/foundry/src/TychoRouter.sol @@ -21,6 +21,7 @@ error TychoRouter__EmptySwaps(); error TychoRouter__NegativeSlippage(uint256 amount, uint256 minAmount); error TychoRouter__AmountInNotFullySpent(uint256 leftoverAmount); error TychoRouter__MessageValueMismatch(uint256 value, uint256 amount); +error TychoRouter__InvalidDataLength(); contract TychoRouter is AccessControl, Dispatcher, Pausable, ReentrancyGuard { IAllowanceTransfer public immutable permit2; @@ -215,12 +216,14 @@ contract TychoRouter is AccessControl, Dispatcher, Pausable, ReentrancyGuard { /** * @dev We use the fallback function to allow flexibility on callback. - * This function will static call a verifier contract and should revert if the - * caller is not a pool. */ fallback() external { - bytes4 selector = bytes4(msg.data[:4]); - _handleCallback(selector, msg.data[4:]); + address executor = + address(uint160(bytes20(msg.data[msg.data.length - 20:]))); + bytes4 selector = + bytes4(msg.data[msg.data.length - 24:msg.data.length - 20]); + bytes memory protocolData = msg.data[:msg.data.length - 24]; + _handleCallback(selector, executor, protocolData); } /** @@ -364,22 +367,31 @@ contract TychoRouter is AccessControl, Dispatcher, Pausable, ReentrancyGuard { function uniswapV3SwapCallback( int256 amount0Delta, int256 amount1Delta, - bytes calldata msgData + bytes calldata data ) external { + if (data.length < 24) revert TychoRouter__InvalidDataLength(); + address executor = address(uint160(bytes20(data[data.length - 20:]))); + bytes4 selector = bytes4(data[data.length - 24:data.length - 20]); + bytes memory protocolData = data[:data.length - 24]; _handleCallback( - bytes4(0), abi.encodePacked(amount0Delta, amount1Delta, msgData) + selector, + executor, + abi.encodePacked(amount0Delta, amount1Delta, protocolData) ); } + /** + * @dev Called by UniswapV4 pool manager after achieving unlock state. + */ function unlockCallback(bytes calldata data) external returns (bytes memory) { - require(data.length >= 20, "Invalid data length"); - bytes4 selector = bytes4(data[data.length - 24:data.length - 20]); + if (data.length < 24) revert TychoRouter__InvalidDataLength(); address executor = address(uint160(bytes20(data[data.length - 20:]))); + bytes4 selector = bytes4(data[data.length - 24:data.length - 20]); bytes memory protocolData = data[:data.length - 24]; - _handleCallback(selector, abi.encodePacked(protocolData, executor)); + _handleCallback(selector, executor, protocolData); return ""; } } diff --git a/foundry/src/executors/UniswapV3Executor.sol b/foundry/src/executors/UniswapV3Executor.sol index 74ff0e4..681edb2 100644 --- a/foundry/src/executors/UniswapV3Executor.sol +++ b/foundry/src/executors/UniswapV3Executor.sol @@ -29,11 +29,10 @@ contract UniswapV3Executor is IExecutor, ICallback { } // slither-disable-next-line locked-ether - function swap(uint256 amountIn, bytes calldata data) - external - payable - returns (uint256 amountOut) - { + function swap( + uint256 amountIn, + bytes calldata data + ) external payable returns (uint256 amountOut) { ( address tokenIn, address tokenOut, @@ -76,7 +75,9 @@ contract UniswapV3Executor is IExecutor, ICallback { abi.encodeWithSelector( ICallback.handleCallback.selector, abi.encodePacked( - amount0Delta, amount1Delta, data[:data.length - 20] + amount0Delta, + amount1Delta, + data[:data.length - 20] ) ) ); @@ -91,28 +92,43 @@ contract UniswapV3Executor is IExecutor, ICallback { } } - function handleCallback(bytes calldata msgData) - external - returns (bytes memory result) - { - (int256 amount0Delta, int256 amount1Delta) = - abi.decode(msgData[:64], (int256, int256)); + function handleCallback( + bytes calldata msgData + ) external returns (bytes memory result) { + (int256 amount0Delta, int256 amount1Delta) = abi.decode( + msgData[:64], + (int256, int256) + ); address tokenIn = address(bytes20(msgData[64:84])); - address tokenOut = address(bytes20(msgData[84:104])); - uint24 poolFee = uint24(bytes3(msgData[104:107])); - // slither-disable-next-line unused-return - CallbackValidationV2.verifyCallback(factory, tokenIn, tokenOut, poolFee); + verifyCallback(msgData[64:]); - uint256 amountOwed = - amount0Delta > 0 ? uint256(amount0Delta) : uint256(amount1Delta); + uint256 amountOwed = amount0Delta > 0 + ? uint256(amount0Delta) + : uint256(amount1Delta); IERC20(tokenIn).safeTransfer(msg.sender, amountOwed); return abi.encode(amountOwed, tokenIn); } - function _decodeData(bytes calldata data) + function verifyCallback(bytes calldata data) public view { + address tokenIn = address(bytes20(data[0:20])); + address tokenOut = address(bytes20(data[20:40])); + uint24 poolFee = uint24(bytes3(data[40:43])); + + // slither-disable-next-line unused-return + CallbackValidationV2.verifyCallback( + factory, + tokenIn, + tokenOut, + poolFee + ); + } + + function _decodeData( + bytes calldata data + ) internal pure returns ( @@ -135,11 +151,18 @@ contract UniswapV3Executor is IExecutor, ICallback { zeroForOne = uint8(data[83]) > 0; } - function _makeV3CallbackData(address tokenIn, address tokenOut, uint24 fee) - internal - view - returns (bytes memory) - { - return abi.encodePacked(tokenIn, tokenOut, fee, self); + function _makeV3CallbackData( + address tokenIn, + address tokenOut, + uint24 fee + ) internal view returns (bytes memory) { + return + abi.encodePacked( + tokenIn, + tokenOut, + fee, + ICallback.handleCallback.selector, + self + ); } }