diff --git a/script/DeploySepolia.sol b/script/DeploySepolia.sol index 6e325fa..7e20f31 100644 --- a/script/DeploySepolia.sol +++ b/script/DeploySepolia.sol @@ -279,7 +279,7 @@ contract DeploySepolia is Script { uint256 inputIndex = 0; uint256 outputIndex = n > 1 ? n - 1 : 0; uint256 maxIn = 89 * 10**6; // varied - pool.swap(msg.sender, msg.sender, inputIndex, outputIndex, maxIn, int128(0), 0, false); + pool.swap(msg.sender, bytes4(0), msg.sender, inputIndex, outputIndex, maxIn, int128(0), 0, false); // 6) Collect protocol fees now (after some swaps) so some will have been moved out pool.collectProtocolFees(); diff --git a/src/IPartyPool.sol b/src/IPartyPool.sol index 068dad2..b2c2f64 100644 --- a/src/IPartyPool.sol +++ b/src/IPartyPool.sol @@ -178,6 +178,7 @@ interface IPartyPool is IERC20Metadata, IOwnable { /// @dev This function transfers the exact gross input (including fee) from payer and sends the computed output to receiver. /// Non-standard tokens (fee-on-transfer, rebasers) are rejected via balance checks. /// @param payer address of the account that pays for the swap + /// @param selector If zero, then regular ERC20 approvals must be given by the payere to the pool to move the required input amount. If this selector is nonzero, then a callback style funding mechanism is used where the given selector is invoked on the payer, passing the arguments of (address inputToken, uint256 inputAmount). The callback function must send the given amount of input coin to the pool in ordr to continue the swap transaction, otherwise "Insufficient funds" is thrown. /// @param receiver address that will receive the output tokens /// @param inputTokenIndex index of input asset /// @param outputTokenIndex index of output asset @@ -187,6 +188,7 @@ interface IPartyPool is IERC20Metadata, IOwnable { /// @return amountIn actual input used (uint256), amountOut actual output sent (uint256), inFee fee taken from the input (uint256) function swap( address payer, + bytes4 selector, address receiver, uint256 inputTokenIndex, uint256 outputTokenIndex, diff --git a/src/PartyPool.sol b/src/PartyPool.sol index 7f685cd..e5a685e 100644 --- a/src/PartyPool.sol +++ b/src/PartyPool.sol @@ -248,6 +248,7 @@ contract PartyPool is PartyPoolBase, OwnableExternal, ERC20External, IPartyPool /// @inheritdoc IPartyPool function swap( address payer, + bytes4 selector, address receiver, uint256 inputTokenIndex, uint256 outputTokenIndex, @@ -266,8 +267,18 @@ contract PartyPool is PartyPoolBase, OwnableExternal, ERC20External, IPartyPool IERC20 tokenIn = _tokens[inputTokenIndex]; IERC20 tokenOut = _tokens[outputTokenIndex]; - // Transfer _tokens in via centralized helper - _receiveTokenFrom(payer, tokenIn, totalTransferAmount); + if ( selector == bytes4(0) ) + // Regular ERC20 permit of the pool to move the tokens + _receiveTokenFrom(payer, tokenIn, totalTransferAmount); + else { + // Callback-style funding mechanism + uint256 startingBalance = tokenIn.balanceOf(address(this)); + bytes memory data = abi.encodeWithSelector(selector, tokenIn, totalTransferAmount); + // Invoke the payer callback; no return value expected (reverts on failure) + Address.functionCall(payer, data); + uint256 endingBalance = tokenIn.balanceOf(address(this)); + require(endingBalance-startingBalance == totalTransferAmount, 'Insufficient funds'); + } // Compute on-chain balances as: onchain = cached + owed (+/- transfer) uint256 balIAfter = _cachedUintBalances[inputTokenIndex] + _protocolFeesOwed[inputTokenIndex] + totalTransferAmount; diff --git a/test/GasTest.sol b/test/GasTest.sol index ef75f4c..9547445 100644 --- a/test/GasTest.sol +++ b/test/GasTest.sol @@ -1,15 +1,25 @@ // SPDX-License-Identifier: UNLICENSED pragma solidity ^0.8.30; -/* solhint-disable erc20-unchecked-transfer */ -import "forge-std/Test.sol"; -import "@abdk/ABDKMath64x64.sol"; -import "@openzeppelin/contracts/token/ERC20/ERC20.sol"; -import "../src/LMSRStabilized.sol"; -import "../src/PartyPool.sol"; -import "../src/PartyPlanner.sol"; -import "../lib/openzeppelin-contracts/contracts/interfaces/IERC3156FlashBorrower.sol"; +import {ABDKMath64x64} from "../lib/abdk-libraries-solidity/ABDKMath64x64.sol"; +import {CommonBase} from "../lib/forge-std/src/Base.sol"; +import {StdAssertions} from "../lib/forge-std/src/StdAssertions.sol"; +import {StdChains} from "../lib/forge-std/src/StdChains.sol"; +import {StdCheats, StdCheatsSafe} from "../lib/forge-std/src/StdCheats.sol"; +import {StdUtils} from "../lib/forge-std/src/StdUtils.sol"; +import {Test} from "../lib/forge-std/src/Test.sol"; +import {IERC3156FlashBorrower} from "../lib/openzeppelin-contracts/contracts/interfaces/IERC3156FlashBorrower.sol"; +import {ERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/ERC20.sol"; +import {IERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/IERC20.sol"; +import {SafeERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/utils/SafeERC20.sol"; +import {IPartyPool} from "../src/IPartyPool.sol"; +import {LMSRStabilized} from "../src/LMSRStabilized.sol"; +import {PartyPlanner} from "../src/PartyPlanner.sol"; +import {PartyPool} from "../src/PartyPool.sol"; import {Deploy} from "./Deploy.sol"; +import {TestERC20, FlashBorrower} from "./GasTest.sol"; + +/* solhint-disable erc20-unchecked-transfer */ /// @notice Test contract that implements the flash callback for testing flash loans contract FlashBorrower is IERC3156FlashBorrower { @@ -99,10 +109,10 @@ contract GasTest is Test { using SafeERC20 for TestERC20; PartyPlanner internal planner; - PartyPool internal pool2; - PartyPool internal pool10; - PartyPool internal pool20; - PartyPool internal pool50; + IPartyPool internal pool2; + IPartyPool internal pool10; + IPartyPool internal pool20; + IPartyPool internal pool50; address internal alice; address internal bob; @@ -115,7 +125,7 @@ contract GasTest is Test { uint256 constant internal BASE = 1; // use base=1 so internal amounts correspond to raw integers (Q64.64 units) /// @notice Helper function to create a pool with the specified number of _tokens - function createPool(uint256 numTokens) internal returns (PartyPool) { + function createPool(uint256 numTokens) internal returns (IPartyPool) { // Deploy _tokens dynamically address[] memory tokens = new address[](numTokens); uint256[] memory bases = new uint256[](numTokens); @@ -141,21 +151,21 @@ contract GasTest is Test { } // Compute kappa from slippage params and number of _tokens, then construct pool with kappa int128 computedKappa = LMSRStabilized.computeKappaFromSlippage(ierc20Tokens.length, tradeFrac, targetSlippage); - PartyPool newPool = Deploy.newPartyPool(address(this), poolName, poolName, ierc20Tokens, computedKappa, feePpm, feePpm, false); - // Transfer initial deposit amounts into pool before initial mint + uint256[] memory initialBalances = new uint256[](numTokens); for (uint256 i = 0; i < numTokens; i++) { - TestERC20(tokens[i]).transfer(address(newPool), INIT_BAL); + initialBalances[i] = INIT_BAL; + ierc20Tokens[i].approve(address(planner), INIT_BAL); } - - // Perform initial mint (initial deposit); receiver is this contract - newPool.initialMint(address(this), 0); + vm.prank(planner.owner()); + (IPartyPool newPool, ) = planner.newPool(poolName, poolName, ierc20Tokens, computedKappa, feePpm, feePpm, false, + address(this), address(this), initialBalances, 0, 0); return newPool; } /// @notice Helper to create a pool with the stable-pair optimization enabled - function createPoolStable(uint256 numTokens) internal returns (PartyPool) { + function createPoolStable(uint256 numTokens) internal returns (IPartyPool) { // Deploy _tokens dynamically address[] memory tokens = new address[](numTokens); uint256[] memory bases = new uint256[](numTokens); @@ -181,7 +191,7 @@ contract GasTest is Test { ierc20Tokens[i] = IERC20(tokens[i]); } int128 computedKappa = LMSRStabilized.computeKappaFromSlippage(ierc20Tokens.length, tradeFrac, targetSlippage); - PartyPool newPool = Deploy.newPartyPool(address(this), poolName, poolName, ierc20Tokens, computedKappa, feePpm, feePpm, true); + IPartyPool newPool = Deploy.newPartyPool(address(this), poolName, poolName, ierc20Tokens, computedKappa, feePpm, feePpm, true); // Transfer initial deposit amounts into pool before initial mint for (uint256 i = 0; i < numTokens; i++) { @@ -227,15 +237,37 @@ contract GasTest is Test { } /// @notice Helper function: perform 10 swaps back-and-forth between the first two _tokens. - function _performSwapGasTest(PartyPool testPool) internal { + function _performSwapGasTest(IPartyPool testPool) internal { + _performSwapGasTest(testPool, false); + } + + function sendTokensCallback(IERC20 token, uint256 amount) external { + // verify the caller + require(planner.getPoolSupported(msg.sender), 'Not a LiqP pool'); + token.transferFrom( alice, msg.sender, amount); + } + + function _performSwapGasTest(IPartyPool testPool, bool useCallback) internal { IERC20[] memory tokens = testPool.allTokens(); require(tokens.length >= 2, "Pool must have at least 2 tokens"); + address payer; + address spender; + bytes4 selector; - // Ensure alice approves pool for both _tokens + if (useCallback) { + payer = address(this); + spender = address(this); + selector = this.sendTokensCallback.selector; + } + else { + payer = alice; + spender = address(testPool); + selector = bytes4(0); + } vm.prank(alice); - TestERC20(address(tokens[0])).approve(address(testPool), type(uint256).max); + TestERC20(address(tokens[0])).approve(spender, type(uint256).max); vm.prank(alice); - TestERC20(address(tokens[1])).approve(address(testPool), type(uint256).max); + TestERC20(address(tokens[1])).approve(spender, type(uint256).max); uint256 maxIn = 10_000; @@ -244,10 +276,10 @@ contract GasTest is Test { vm.prank(alice); if (i % 2 == 0) { // swap token0 -> token1 - testPool.swap(alice, alice, 0, 1, maxIn, 0, 0, false); + testPool.swap(payer, selector, alice, 0, 1, maxIn, 0, 0, false); } else { // swap token1 -> token0 - testPool.swap(alice, alice, 1, 0, maxIn, 0, 0, false); + testPool.swap(payer, selector, alice, 1, 0, maxIn, 0, 0, false); } // shake up the bits maxIn *= 787; @@ -265,6 +297,12 @@ contract GasTest is Test { _performSwapGasTest(pool10); } + /// @notice Gas measurement: perform 10 swaps back-and-forth between first two _tokens in the 10-token pool using the callback funding method. + function testSwapGasCallback() public { + _performSwapGasTest(pool10, true); + } + + /// @notice Gas measurement: perform 10 swaps back-and-forth between first two _tokens in the 20-token pool. function testSwapGasTwenty() public { _performSwapGasTest(pool20); @@ -277,24 +315,24 @@ contract GasTest is Test { /// @notice Gas measurement: perform 10 swaps back-and-forth on a 2-token stable pair (stable-path enabled) function testSwapGasStablePair() public { - PartyPool stablePair = createPoolStable(2); + IPartyPool stablePair = createPoolStable(2); _performSwapGasTest(stablePair); } /// @notice Gas-style test: alternate swapMint then burnSwap on a 2-token stable pair function testSwapMintBurnSwapGasStablePair() public { - PartyPool stablePair = createPoolStable(2); + IPartyPool stablePair = createPoolStable(2); _performSwapMintBurnSwapGasTest(stablePair); } /// @notice Combined gas test (mint then burn) on 2-token stable pair using mint() and burn(). function testMintBurnGasStablePair() public { - PartyPool stablePair = createPoolStable(2); + IPartyPool stablePair = createPoolStable(2); _performMintBurnGasTest(stablePair); } /// @notice Helper function: alternate swapMint then burnSwap to keep pool size roughly stable. - function _performSwapMintBurnSwapGasTest(PartyPool testPool) internal { + function _performSwapMintBurnSwapGasTest(IPartyPool testPool) internal { uint256 iterations = 10; uint256 input = 1_000; IERC20[] memory tokens = testPool.allTokens(); @@ -339,7 +377,7 @@ contract GasTest is Test { /// @notice Helper function: combined gas test (mint then burn) using mint() and burn(). /// Alternates minting a tiny LP amount and immediately burning the actual minted LP back to avoid net pool depletion. - function _performMintBurnGasTest(PartyPool testPool) internal { + function _performMintBurnGasTest(IPartyPool testPool) internal { uint256 iterations = 50; uint256 input = 1_000; IERC20[] memory poolTokens = testPool.allTokens(); diff --git a/test/NativeTest.t.sol b/test/NativeTest.t.sol index 707e779..bc505e1 100644 --- a/test/NativeTest.t.sol +++ b/test/NativeTest.t.sol @@ -142,6 +142,7 @@ contract NativeTest is Test { // Send native currency with {value: maxIn} (uint256 amountIn, uint256 amountOut, ) = pool.swap{value: maxIn}( alice, // payer + bytes4(0), alice, // receiver 2, // inputTokenIndex (WETH) 0, // outputTokenIndex (token0) @@ -179,6 +180,7 @@ contract NativeTest is Test { // Execute swap: token0 (index 0) -> WETH (index 2) with unwrap=true (uint256 amountIn, uint256 amountOut, ) = pool.swap( alice, // payer + bytes4(0), // no selector: use ERC20 approvals alice, // receiver 0, // inputTokenIndex (token0) 2, // outputTokenIndex (WETH) @@ -214,6 +216,7 @@ contract NativeTest is Test { // Execute swap with excess native currency (uint256 amountIn, , ) = pool.swap{value: totalSent}( alice, // payer + bytes4(0), alice, // receiver 2, // inputTokenIndex (WETH) 0, // outputTokenIndex (token0) @@ -542,14 +545,14 @@ contract NativeTest is Test { // 2. Swap native currency for token0 uint256 swapAmount = 5_000; (, uint256 amountOut, ) = pool.swap{value: swapAmount}( - alice, alice, 2, 0, swapAmount, 0, 0, false + alice,bytes4(0),alice, 2, 0, swapAmount, 0, 0, false ); assertTrue(amountOut > 0, "Should receive token0"); // 3. Swap token0 back to native currency uint256 token0Balance = token0.balanceOf(alice); (, uint256 swapOut2, ) = pool.swap( - alice, alice, 0, 2, token0Balance / 2, 0, 0, true + alice, bytes4(0), alice, 0, 2, token0Balance / 2, 0, 0, true ); assertTrue(swapOut2 > 0, "Should receive native currency"); @@ -576,7 +579,7 @@ contract NativeTest is Test { // Swap token0 -> WETH without unwrap (, uint256 amountOut, ) = pool.swap( - alice, alice, 0, 2, maxIn, 0, 0, false // unwrap=false + alice, bytes4(0), alice, 0, 2, maxIn, 0, 0, false // unwrap=false ); assertTrue(amountOut > 0, "Should receive WETH tokens"); @@ -597,7 +600,7 @@ contract NativeTest is Test { // Try to swap token0 (not WETH) by sending native currency - should revert vm.expectRevert(); pool.swap{value: 10_000}( - alice, alice, 0, 1, 10_000, 0, 0, false + alice, bytes4(0), alice, 0, 1, 10_000, 0, 0, false ); vm.stopPrank(); diff --git a/test/PartyPool.t.sol b/test/PartyPool.t.sol index 28ada20..95651af 100644 --- a/test/PartyPool.t.sol +++ b/test/PartyPool.t.sol @@ -423,7 +423,7 @@ contract PartyPoolTest is Test { // Execute swap: token0 -> token1 vm.prank(alice); - (uint256 amountInUsed, uint256 amountOut, uint256 fee) = pool.swap(alice, bob, 0, 1, maxIn, 0, 0, false); + (uint256 amountInUsed, uint256 amountOut, uint256 fee) = pool.swap(alice, bytes4(0), bob, 0, 1, maxIn, 0, 0, false); // Amounts should be positive and not exceed provided max assertTrue(amountInUsed > 0, "expected some input used"); @@ -452,7 +452,7 @@ contract PartyPoolTest is Test { vm.prank(alice); vm.expectRevert(bytes("LMSR: limitPrice <= current price")); - pool.swap(alice, alice, 0, 1, 1000, limitPrice, 0, false); + pool.swap(alice, bytes4(0), alice, 0, 1, 1000, limitPrice, 0, false); } /// @notice swapToLimit should compute input needed to reach a slightly higher price and execute. @@ -1024,8 +1024,8 @@ contract PartyPoolTest is Test { token0.approve(address(poolCustom), type(uint256).max); // Perform identical swaps: token0 -> token1 - (uint256 amountInDefault, uint256 amountOutDefault, uint256 feeDefault) = poolDefault.swap(alice, alice, 0, 1, swapAmount, 0, 0, false); - (uint256 amountInCustom, uint256 amountOutCustom, uint256 feeCustom) = poolCustom.swap(alice, alice, 0, 1, swapAmount, 0, 0, false); + (uint256 amountInDefault, uint256 amountOutDefault, uint256 feeDefault) = poolDefault.swap(alice, bytes4(0), alice, 0, 1, swapAmount, 0, 0, false); + (uint256 amountInCustom, uint256 amountOutCustom, uint256 feeCustom) = poolCustom.swap(alice, bytes4(0), alice, 0, 1, swapAmount, 0, 0, false); // Swap results should be identical assertEq(amountInDefault, amountInCustom, "Swap input amounts should be identical");