diff --git a/foundry/src/executors/UniswapV3Executor.sol b/foundry/src/executors/UniswapV3Executor.sol index 7b6383f..e243dc9 100644 --- a/foundry/src/executors/UniswapV3Executor.sol +++ b/foundry/src/executors/UniswapV3Executor.sol @@ -5,13 +5,14 @@ import "@interfaces/IExecutor.sol"; import "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol"; import "@uniswap/v3-core/contracts/interfaces/IUniswapV3Pool.sol"; import "@interfaces/ICallback.sol"; +import {ExecutorTransferMethods} from "./ExecutorTransferMethods.sol"; error UniswapV3Executor__InvalidDataLength(); error UniswapV3Executor__InvalidFactory(); error UniswapV3Executor__InvalidTarget(); error UniswapV3Executor__InvalidInitCode(); -contract UniswapV3Executor is IExecutor, ICallback { +contract UniswapV3Executor is IExecutor, ICallback, ExecutorTransferMethods { using SafeERC20 for IERC20; uint160 private constant MIN_SQRT_RATIO = 4295128739; @@ -22,7 +23,9 @@ contract UniswapV3Executor is IExecutor, ICallback { bytes32 public immutable initCode; address private immutable self; - constructor(address _factory, bytes32 _initCode) { + constructor(address _factory, bytes32 _initCode, address _permit2) + ExecutorTransferMethods(_permit2) + { if (_factory == address(0)) { revert UniswapV3Executor__InvalidFactory(); } @@ -46,7 +49,8 @@ contract UniswapV3Executor is IExecutor, ICallback { uint24 fee, address receiver, address target, - bool zeroForOne + bool zeroForOne, + TransferMethod method ) = _decodeData(data); _verifyPairAddress(tokenIn, tokenOut, fee, target); @@ -55,7 +59,8 @@ contract UniswapV3Executor is IExecutor, ICallback { int256 amount1; IUniswapV3Pool pool = IUniswapV3Pool(target); - bytes memory callbackData = _makeV3CallbackData(tokenIn, tokenOut, fee); + bytes memory callbackData = + _makeV3CallbackData(tokenIn, tokenOut, fee, method); { (amount0, amount1) = pool.swap( @@ -92,12 +97,20 @@ contract UniswapV3Executor is IExecutor, ICallback { address tokenIn = address(bytes20(msgData[132:152])); + require( + uint8(msgData[171]) <= uint8(TransferMethod.NONE), + "InvalidTransferMethod" + ); + TransferMethod method = TransferMethod(uint8(msgData[171])); + verifyCallback(msgData[132:]); uint256 amountOwed = amount0Delta > 0 ? uint256(amount0Delta) : uint256(amount1Delta); - IERC20(tokenIn).safeTransfer(msg.sender, amountOwed); + // TODO This must never be a safeTransfer. Figure out how to ensure this. + _transfer(IERC20(tokenIn), msg.sender, amountOwed, method); + return abi.encode(amountOwed, tokenIn); } @@ -132,7 +145,8 @@ contract UniswapV3Executor is IExecutor, ICallback { uint24 fee, address receiver, address target, - bool zeroForOne + bool zeroForOne, + TransferMethod method ) { if (data.length != 84) { @@ -144,14 +158,16 @@ contract UniswapV3Executor is IExecutor, ICallback { receiver = address(bytes20(data[43:63])); target = address(bytes20(data[63:83])); zeroForOne = uint8(data[83]) > 0; + method = TransferMethod.TRANSFER; } - function _makeV3CallbackData(address tokenIn, address tokenOut, uint24 fee) - internal - pure - returns (bytes memory) - { - return abi.encodePacked(tokenIn, tokenOut, fee); + function _makeV3CallbackData( + address tokenIn, + address tokenOut, + uint24 fee, + TransferMethod method + ) internal pure returns (bytes memory) { + return abi.encodePacked(tokenIn, tokenOut, fee, uint8(method), self); } function _verifyPairAddress( diff --git a/foundry/test/TychoRouterTestSetup.sol b/foundry/test/TychoRouterTestSetup.sol index 89ea394..acf5a7a 100644 --- a/foundry/test/TychoRouterTestSetup.sol +++ b/foundry/test/TychoRouterTestSetup.sol @@ -98,7 +98,7 @@ contract TychoRouterTestSetup is Constants { IPoolManager poolManager = IPoolManager(poolManagerAddress); usv2Executor = new UniswapV2Executor(factoryV2, initCodeV2, PERMIT2_ADDRESS); - usv3Executor = new UniswapV3Executor(factoryV3, initCodeV3); + usv3Executor = new UniswapV3Executor(factoryV3, initCodeV3, PERMIT2_ADDRESS); usv4Executor = new UniswapV4Executor(poolManager); pancakev3Executor = new UniswapV3Executor(factoryPancakeV3, initCodePancakeV3); diff --git a/foundry/test/executors/UniswapV2Executor.t.sol b/foundry/test/executors/UniswapV2Executor.t.sol index 3f303fd..4a341f8 100644 --- a/foundry/test/executors/UniswapV2Executor.t.sol +++ b/foundry/test/executors/UniswapV2Executor.t.sol @@ -87,7 +87,10 @@ contract UniswapV2ExecutorTest is Test, Constants { assertEq(target, address(2)); assertEq(receiver, address(3)); assertEq(zeroForOne, false); - assertEq(0, uint8(method)); + assertEq( + uint8(ExecutorTransferMethods.TransferMethod.TRANSFER), + uint8(method) + ); } function testDecodeParamsInvalidDataLength() public { diff --git a/foundry/test/executors/UniswapV3Executor.t.sol b/foundry/test/executors/UniswapV3Executor.t.sol index 59b03ba..2033edf 100644 --- a/foundry/test/executors/UniswapV3Executor.t.sol +++ b/foundry/test/executors/UniswapV3Executor.t.sol @@ -6,8 +6,8 @@ import {Test} from "../../lib/forge-std/src/Test.sol"; import {Constants} from "../Constants.sol"; contract UniswapV3ExecutorExposed is UniswapV3Executor { - constructor(address _factory, bytes32 _initCode) - UniswapV3Executor(_factory, _initCode) + constructor(address _factory, bytes32 _initCode, address _permit2) + UniswapV3Executor(_factory, _initCode, _permit2) {} function decodeData(bytes calldata data) @@ -19,7 +19,8 @@ contract UniswapV3ExecutorExposed is UniswapV3Executor { uint24 fee, address receiver, address target, - bool zeroForOne + bool zeroForOne, + TransferMethod method ) { return _decodeData(data); @@ -48,10 +49,10 @@ contract UniswapV3ExecutorTest is Test, Constants { vm.createSelectFork(vm.rpcUrl("mainnet"), forkBlock); uniswapV3Exposed = new UniswapV3ExecutorExposed( - USV3_FACTORY_ETHEREUM, USV3_POOL_CODE_INIT_HASH + USV3_FACTORY_ETHEREUM, USV3_POOL_CODE_INIT_HASH, PERMIT2_ADDRESS ); pancakeV3Exposed = new UniswapV3ExecutorExposed( - PANCAKESWAPV3_DEPLOYER_ETHEREUM, PANCAKEV3_POOL_CODE_INIT_HASH + PANCAKESWAPV3_DEPLOYER_ETHEREUM, PANCAKEV3_POOL_CODE_INIT_HASH, PERMIT2_ADDRESS ); } @@ -67,7 +68,8 @@ contract UniswapV3ExecutorTest is Test, Constants { uint24 fee, address receiver, address target, - bool zeroForOne + bool zeroForOne, + ExecutorTransferMethods.TransferMethod method ) = uniswapV3Exposed.decodeData(data); assertEq(tokenIn, WETH_ADDR); @@ -76,6 +78,10 @@ contract UniswapV3ExecutorTest is Test, Constants { assertEq(receiver, address(2)); assertEq(target, address(3)); assertEq(zeroForOne, false); + assertEq( + uint8(method), + uint8(ExecutorTransferMethods.TransferMethod.TRANSFER) + ); } function testDecodeParamsInvalidDataLength() public { @@ -116,7 +122,8 @@ contract UniswapV3ExecutorTest is Test, Constants { int256(0), // amount1Delta dataOffset, dataLength, - protocolData + protocolData, + uint8(ExecutorTransferMethods.TransferMethod.TRANSFER) ); uniswapV3Exposed.handleCallback(callbackData); vm.stopPrank();