diff --git a/foundry/src/Dispatcher.sol b/foundry/src/Dispatcher.sol index d23321c..03292be 100644 --- a/foundry/src/Dispatcher.sol +++ b/foundry/src/Dispatcher.sol @@ -82,11 +82,10 @@ contract Dispatcher { calculatedAmount = abi.decode(result, (uint256)); } - function _handleCallback( - bytes4 selector, - address executor, - bytes memory data - ) internal { + function _handleCallback(bytes calldata data) internal { + bytes4 selector = bytes4(data[data.length - 4:]); + address executor = address(uint160(bytes20(data[data.length - 24:]))); + if (!executors[executor]) { revert Dispatcher__UnapprovedExecutor(); } diff --git a/foundry/src/TychoRouter.sol b/foundry/src/TychoRouter.sol index 49b9094..60a9560 100644 --- a/foundry/src/TychoRouter.sol +++ b/foundry/src/TychoRouter.sol @@ -218,12 +218,7 @@ contract TychoRouter is AccessControl, Dispatcher, Pausable, ReentrancyGuard { * @dev We use the fallback function to allow flexibility on callback. */ fallback() external { - address executor = - address(uint160(bytes20(msg.data[msg.data.length - 20:]))); - bytes4 selector = - bytes4(msg.data[msg.data.length - 24:msg.data.length - 20]); - bytes memory protocolData = msg.data[:msg.data.length - 24]; - _handleCallback(selector, executor, protocolData); + _handleCallback(msg.data); } /** @@ -365,19 +360,17 @@ contract TychoRouter is AccessControl, Dispatcher, Pausable, ReentrancyGuard { * See in IUniswapV3SwapCallback for documentation. */ function uniswapV3SwapCallback( - int256 amount0Delta, - int256 amount1Delta, + int256, /* amount0Delta */ + int256, /* amount1Delta */ bytes calldata data ) external { if (data.length < 24) revert TychoRouter__InvalidDataLength(); - bytes4 selector = bytes4(data[data.length - 4:]); - address executor = address(uint160(bytes20(data[data.length - 24:]))); - bytes memory protocolData = data[:data.length - 24]; - _handleCallback( - selector, - executor, - abi.encodePacked(amount0Delta, amount1Delta, protocolData) - ); + uint256 dataOffset = 4 + 32 + 32 + 32; // Skip selector + 2 ints + data_offset + uint256 dataLength = + uint256(bytes32(msg.data[dataOffset:dataOffset + 32])); + + bytes calldata fullData = msg.data[4:dataOffset + 32 + dataLength]; + _handleCallback(fullData); } /** @@ -388,11 +381,7 @@ contract TychoRouter is AccessControl, Dispatcher, Pausable, ReentrancyGuard { returns (bytes memory) { if (data.length < 24) revert TychoRouter__InvalidDataLength(); - bytes4 selector = bytes4(data[data.length - 4:]); - address executor = - address(uint160(bytes20(data[data.length - 24:data.length - 4]))); - bytes memory protocolData = data[:data.length - 24]; - _handleCallback(selector, executor, protocolData); + _handleCallback(data); return ""; } } diff --git a/foundry/src/executors/UniswapV3Executor.sol b/foundry/src/executors/UniswapV3Executor.sol index 4472eef..d1719fe 100644 --- a/foundry/src/executors/UniswapV3Executor.sol +++ b/foundry/src/executors/UniswapV3Executor.sol @@ -66,41 +66,23 @@ contract UniswapV3Executor is IExecutor, ICallback { } } - function uniswapV3SwapCallback( - int256 amount0Delta, - int256 amount1Delta, - bytes calldata data - ) external { - // slither-disable-next-line low-level-calls - (bool success, bytes memory result) = self.delegatecall( - abi.encodeWithSelector( - ICallback.handleCallback.selector, - abi.encodePacked( - amount0Delta, amount1Delta, data[:data.length - 20] - ) - ) - ); - if (!success) { - revert( - string( - result.length > 0 - ? result - : abi.encodePacked("Callback failed") - ) - ); - } - } - function handleCallback(bytes calldata msgData) - external + public returns (bytes memory result) { + // The data has the following layout: + // - amount0Delta (32 bytes) + // - amount1Delta (32 bytes) + // - dataOffset (32 bytes) + // - dataLength (32 bytes) + // - protocolData (variable length) + (int256 amount0Delta, int256 amount1Delta) = abi.decode(msgData[:64], (int256, int256)); - address tokenIn = address(bytes20(msgData[64:84])); + address tokenIn = address(bytes20(msgData[128:148])); - verifyCallback(msgData[64:]); + verifyCallback(msgData[128:]); uint256 amountOwed = amount0Delta > 0 ? uint256(amount0Delta) : uint256(amount1Delta); @@ -118,6 +100,20 @@ contract UniswapV3Executor is IExecutor, ICallback { CallbackValidationV2.verifyCallback(factory, tokenIn, tokenOut, poolFee); } + function uniswapV3SwapCallback( + int256, /* amount0Delta */ + int256, /* amount1Delta */ + bytes calldata /* data */ + ) external { + uint256 dataOffset = 4 + 32 + 32 + 32; // Skip selector + 2 ints + data_offset + uint256 dataLength = + uint256(bytes32(msg.data[dataOffset:dataOffset + 32])); + + bytes calldata fullData = msg.data[4:dataOffset + 32 + dataLength]; + + handleCallback(fullData); + } + function _decodeData(bytes calldata data) internal pure diff --git a/foundry/test/executors/UniswapV3Executor.t.sol b/foundry/test/executors/UniswapV3Executor.t.sol index e9a9f57..86b999b 100644 --- a/foundry/test/executors/UniswapV3Executor.t.sol +++ b/foundry/test/executors/UniswapV3Executor.t.sol @@ -76,12 +76,17 @@ contract UniswapV3ExecutorTest is Test, Constants { uint256 initialPoolReserve = IERC20(WETH_ADDR).balanceOf(DAI_WETH_USV3); vm.startPrank(DAI_WETH_USV3); + bytes memory protocolData = + abi.encodePacked(WETH_ADDR, DAI_ADDR, poolFee); + uint256 dataOffset = 3; // some offset + uint256 dataLength = protocolData.length; + bytes memory callbackData = abi.encodePacked( int256(amountOwed), // amount0Delta int256(0), // amount1Delta - WETH_ADDR, - DAI_ADDR, - poolFee + dataOffset, + dataLength, + protocolData ); uniswapV3Exposed.handleCallback(callbackData); vm.stopPrank(); @@ -90,7 +95,7 @@ contract UniswapV3ExecutorTest is Test, Constants { assertEq(finalPoolReserve - initialPoolReserve, amountOwed); } - function testSwapWETHForDAI() public { + function testSwapIntegration() public { uint256 amountIn = 10 ** 18; deal(WETH_ADDR, address(uniswapV3Exposed), amountIn);