From 96535ed005e57f7b8d31bcc8f8d88591ad70f226 Mon Sep 17 00:00:00 2001 From: tim Date: Tue, 14 Oct 2025 20:54:15 -0400 Subject: [PATCH] native currency fixes --- src/IPartyPool.sol | 19 +++++++++++-------- src/PartyPool.sol | 38 +++++++++++++++++++------------------- src/PartyPoolBase.sol | 20 ++++++++++++-------- src/PartyPoolMintImpl.sol | 16 ++++++++-------- src/PartyPoolSwapImpl.sol | 7 +++---- test/GasTest.sol | 8 ++++---- test/PartyPool.t.sol | 18 +++++++++--------- 7 files changed, 66 insertions(+), 60 deletions(-) diff --git a/src/IPartyPool.sol b/src/IPartyPool.sol index 802c054..0b8f356 100644 --- a/src/IPartyPool.sol +++ b/src/IPartyPool.sol @@ -109,7 +109,7 @@ interface IPartyPool is IERC20Metadata { /// Can only be called when the pool is uninitialized (totalSupply() == 0 or lmsr.nAssets == 0). /// @param receiver address that receives the LP tokens /// @param lpTokens The number of LP tokens to issue for this mint. If 0, then the number of tokens returned will equal the LMSR internal q total - function initialMint(address receiver, uint256 lpTokens) external returns (uint256 lpMinted); + function initialMint(address receiver, uint256 lpTokens) external payable returns (uint256 lpMinted); /// @notice Proportional mint (or initial supply if first call). /// @dev - For initial supply: assumes tokens have already been transferred to the pool prior to calling. @@ -120,16 +120,16 @@ interface IPartyPool is IERC20Metadata { /// @param lpTokenAmount desired amount of LP tokens to mint (ignored for initial deposit) /// @param deadline timestamp after which the transaction will revert. Pass 0 to ignore. /// @return lpMinted the actual amount of lpToken minted - function mint(address payer, address receiver, uint256 lpTokenAmount, uint256 deadline) external returns (uint256 lpMinted); + function mint(address payer, address receiver, uint256 lpTokenAmount, uint256 deadline) external payable returns (uint256 lpMinted); /// @notice Burn LP tokens and withdraw the proportional basket to receiver. - /// @dev Payer must own or approve the LP tokens being burned. The function updates LMSR state - /// proportionally to reflect the reduced pool size after the withdrawal. + /// @dev This function forwards the call to the burn implementation via delegatecall /// @param payer address that provides the LP tokens to burn /// @param receiver address that receives the withdrawn tokens /// @param lpAmount amount of LP tokens to burn (proportional withdrawal) /// @param deadline timestamp after which the transaction will revert. Pass 0 to ignore. - function burn(address payer, address receiver, uint256 lpAmount, uint256 deadline) external returns (uint256[] memory withdrawAmounts); + /// @param unwrap if true and the native token is being withdrawn, it is unwraped and sent as native currency + function burn(address payer, address receiver, uint256 lpAmount, uint256 deadline, bool unwrap) external returns (uint256[] memory withdrawAmounts); // Swaps @@ -165,7 +165,8 @@ interface IPartyPool is IERC20Metadata { uint256 outputTokenIndex, uint256 maxAmountIn, int128 limitPrice, - uint256 deadline + uint256 deadline, + bool unwrap ) external payable returns (uint256 amountIn, uint256 amountOut, uint256 fee); /// @notice Swap up to the price limit; computes max input to reach limit then performs swap. @@ -184,7 +185,8 @@ interface IPartyPool is IERC20Metadata { uint256 inputTokenIndex, uint256 outputTokenIndex, int128 limitPrice, - uint256 deadline + uint256 deadline, + bool unwrap ) external payable returns (uint256 amountInUsed, uint256 amountOut, uint256 fee); /// @notice Single-token mint: deposit a single token, charge swap-LMSR cost, and mint LP. @@ -217,7 +219,8 @@ interface IPartyPool is IERC20Metadata { address receiver, uint256 lpAmount, uint256 inputTokenIndex, - uint256 deadline + uint256 deadline, + bool unwrap ) external returns (uint256 amountOutUint); /** diff --git a/src/PartyPool.sol b/src/PartyPool.sol index 3192db9..877cf8b 100644 --- a/src/PartyPool.sol +++ b/src/PartyPool.sol @@ -151,7 +151,7 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { ---------------------- */ /// @inheritdoc IPartyPool - function initialMint(address receiver, uint256 lpTokens) external + function initialMint(address receiver, uint256 lpTokens) external payable returns (uint256 lpMinted) { bytes memory data = abi.encodeWithSelector( PartyPoolMintImpl.initialMint.selector, @@ -169,7 +169,7 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { /// @param receiver address that receives the LP tokens /// @param lpTokenAmount desired amount of LP tokens to mint /// @param deadline timestamp after which the transaction will revert. Pass 0 to ignore. - function mint(address payer, address receiver, uint256 lpTokenAmount, uint256 deadline) external + function mint(address payer, address receiver, uint256 lpTokenAmount, uint256 deadline) external payable returns (uint256 lpMinted) { bytes memory data = abi.encodeWithSelector( PartyPoolMintImpl.mint.selector, @@ -182,20 +182,16 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { return abi.decode(result, (uint256)); } - /// @notice Burn LP tokens and withdraw the proportional basket to receiver. - /// @dev This function forwards the call to the burn implementation via delegatecall - /// @param payer address that provides the LP tokens to burn - /// @param receiver address that receives the withdrawn tokens - /// @param lpAmount amount of LP tokens to burn (proportional withdrawal) - /// @param deadline timestamp after which the transaction will revert. Pass 0 to ignore. - function burn(address payer, address receiver, uint256 lpAmount, uint256 deadline) external + /// @inheritdoc IPartyPool + function burn(address payer, address receiver, uint256 lpAmount, uint256 deadline, bool unwrap) external returns (uint256[] memory withdrawAmounts) { bytes memory data = abi.encodeWithSelector( PartyPoolMintImpl.burn.selector, payer, receiver, lpAmount, - deadline + deadline, + unwrap ); bytes memory result = Address.functionDelegateCall(address(MINT_IMPL), data); return abi.decode(result, (uint256[])); @@ -223,8 +219,9 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { uint256 outputTokenIndex, uint256 maxAmountIn, int128 limitPrice, - uint256 deadline - ) external payable nonReentrant returns (uint256 amountIn, uint256 amountOut, uint256 fee) { + uint256 deadline, + bool unwrap + ) external payable native nonReentrant returns (uint256 amountIn, uint256 amountOut, uint256 fee) { require(deadline == 0 || block.timestamp <= deadline, "swap: deadline exceeded"); // Compute amounts using the same path as views @@ -243,7 +240,7 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { uint256 balJAfter = cachedUintBalances[outputTokenIndex] + protocolFeesOwed[outputTokenIndex] - amountOutUint; // Transfer output to receiver via centralized helper - _sendTokenTo(tokenOut, receiver, amountOutUint); + _sendTokenTo(tokenOut, receiver, amountOutUint, unwrap); // Accrue protocol share (floor) from the fee on input token if (PROTOCOL_FEE_PPM > 0 && feeUint > 0) { @@ -265,8 +262,6 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { emit Swap(payer, receiver, tokenIn, tokenOut, totalTransferAmount, amountOutUint); - _refund(); - return (totalTransferAmount, amountOutUint, feeUint); } @@ -302,6 +297,7 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { require(deltaInternalI > int128(0), "swap: input too small after fee"); // Compute internal amounts using LMSR (exact-input with price limit) + // use the virtual method call so that the balanced pair optimization can override (amountInInternalUsed, amountOutInternal) = _swapAmountsForExactInput(inputTokenIndex, outputTokenIndex, deltaInternalI, limitPrice); // Convert actual used input internal -> uint (ceil) @@ -330,7 +326,8 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { uint256 inputTokenIndex, uint256 outputTokenIndex, int128 limitPrice, - uint256 deadline + uint256 deadline, + bool unwrap ) external payable returns (uint256 amountInUsed, uint256 amountOut, uint256 fee) { bytes memory data = abi.encodeWithSelector( PartyPoolSwapImpl.swapToLimit.selector, @@ -340,6 +337,7 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { outputTokenIndex, limitPrice, deadline, + unwrap, SWAP_FEE_PPM, PROTOCOL_FEE_PPM ); @@ -391,7 +389,8 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { address receiver, uint256 lpAmount, uint256 inputTokenIndex, - uint256 deadline + uint256 deadline, + bool unwrap ) external returns (uint256 amountOutUint) { bytes memory data = abi.encodeWithSelector( PartyPoolMintImpl.burnSwap.selector, @@ -400,6 +399,7 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { lpAmount, inputTokenIndex, deadline, + unwrap, SWAP_FEE_PPM, PROTOCOL_FEE_PPM ); @@ -438,7 +438,7 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { } } - _sendTokenTo(token, address(receiver), amount); + _sendTokenTo(token, address(receiver), amount, false); require(receiver.onFlashLoan(msg.sender, address(token), amount, fee, data) == FLASH_CALLBACK_SUCCESS); _receiveTokenFrom(address(receiver), token, amount + fee); @@ -466,7 +466,7 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { require(bal >= owed, "collect: fee > bal"); protocolFeesOwed[i] = 0; // transfer owed tokens to protocol destination via centralized helper - _sendTokenTo(tokens[i], dest, owed); + _sendTokenTo(tokens[i], dest, owed, false); // update cached to effective onchain minus owed cachedUintBalances[i] = bal - owed; } diff --git a/src/PartyPoolBase.sol b/src/PartyPoolBase.sol index 20ce59b..3fc7ffb 100644 --- a/src/PartyPoolBase.sol +++ b/src/PartyPoolBase.sol @@ -23,6 +23,15 @@ abstract contract PartyPoolBase is ERC20Internal, ReentrancyGuard, PartyPoolHelp WRAPPER_TOKEN = wrapper_; } + /// @notice Designates methods that can receive native currency. + /// @dev If the pool has any balance of native currency at the end of the method, it is refunded to msg.sender + modifier native() { + _; + uint256 bal = address(this).balance; + if(bal > 0) + payable(msg.sender).transfer(bal); + } + // // Internal state (no immutables here; immutables belong to derived contracts) // @@ -104,19 +113,14 @@ abstract contract PartyPoolBase is ERC20Internal, ReentrancyGuard, PartyPoolHelp /// @notice Send tokens from the pool to `receiver` using SafeERC20 semantics. /// @dev Note: this helper does NOT query the on-chain balance after transfer to save gas. /// Callers should query the balance themselves when they need it (e.g., to detect fee-on-transfer tokens). - function _sendTokenTo(IERC20 token, address receiver, uint256 amount) internal { - if( token == WRAPPER_TOKEN ) { + function _sendTokenTo(IERC20 token, address receiver, uint256 amount, bool unwrap) internal { + if( unwrap && token == WRAPPER_TOKEN ) { WRAPPER_TOKEN.withdraw(amount); (bool ok, ) = receiver.call{value: amount}(""); - require(ok); // todo make unwrapping optional + require(ok, 'receiver not payable'); } else token.safeTransfer(receiver, amount); } - function _refund() internal { - uint256 bal = address(this).balance; - if(bal > 0) - payable(msg.sender).transfer(bal); - } } diff --git a/src/PartyPoolMintImpl.sol b/src/PartyPoolMintImpl.sol index 861c782..1ff3310 100644 --- a/src/PartyPoolMintImpl.sol +++ b/src/PartyPoolMintImpl.sol @@ -25,7 +25,7 @@ contract PartyPoolMintImpl is PartyPoolBase { // Initialization Mint // - function initialMint(address receiver, uint256 lpTokens, int128 KAPPA) external nonReentrant + function initialMint(address receiver, uint256 lpTokens, int128 KAPPA) external payable native nonReentrant returns (uint256 lpMinted) { uint256 n = tokens.length; @@ -65,7 +65,7 @@ contract PartyPoolMintImpl is PartyPoolBase { // Regular Mint and Burn // - function mint(address payer, address receiver, uint256 lpTokenAmount, uint256 deadline) external payable nonReentrant + function mint(address payer, address receiver, uint256 lpTokenAmount, uint256 deadline) external payable native nonReentrant returns (uint256 lpMinted) { require(deadline == 0 || block.timestamp <= deadline, "mint: deadline exceeded"); uint256 n = tokens.length; @@ -128,8 +128,6 @@ contract PartyPoolMintImpl is PartyPoolBase { _mint(receiver, actualLpToMint); emit IPartyPool.Mint(payer, receiver, depositAmounts, actualLpToMint); - _refund(); - return actualLpToMint; } @@ -140,7 +138,8 @@ contract PartyPoolMintImpl is PartyPoolBase { /// @param receiver address that receives the withdrawn tokens /// @param lpAmount amount of LP tokens to burn (proportional withdrawal) /// @param deadline timestamp after which the transaction will revert. Pass 0 to ignore. - function burn(address payer, address receiver, uint256 lpAmount, uint256 deadline) external nonReentrant + /// @param unwrap if true and the native token is being withdrawn, it is unwraped and sent as native currency + function burn(address payer, address receiver, uint256 lpAmount, uint256 deadline, bool unwrap) external nonReentrant returns (uint256[] memory withdrawAmounts) { require(deadline == 0 || block.timestamp <= deadline, "burn: deadline exceeded"); uint256 n = tokens.length; @@ -157,7 +156,7 @@ contract PartyPoolMintImpl is PartyPoolBase { // Transfer underlying tokens out to receiver according to computed proportions for (uint i = 0; i < n; ) { if (withdrawAmounts[i] > 0) { - _sendTokenTo(tokens[i], receiver, withdrawAmounts[i]); + _sendTokenTo(tokens[i], receiver, withdrawAmounts[i], unwrap); } unchecked { i++; } } @@ -346,7 +345,7 @@ contract PartyPoolMintImpl is PartyPoolBase { uint256 deadline, uint256 swapFeePpm, uint256 protocolFeePpm - ) external nonReentrant returns (uint256 lpMinted) { + ) external payable native nonReentrant returns (uint256 lpMinted) { uint256 n = tokens.length; require(inputTokenIndex < n, "swapMint: idx"); require(maxAmountIn > 0, "swapMint: input zero"); @@ -483,6 +482,7 @@ contract PartyPoolMintImpl is PartyPoolBase { uint256 lpAmount, uint256 inputTokenIndex, uint256 deadline, + bool unwrap, uint256 swapFeePpm, uint256 protocolFeePpm ) external nonReentrant returns (uint256 amountOutUint) { @@ -521,7 +521,7 @@ contract PartyPoolMintImpl is PartyPoolBase { } // Transfer the payout to receiver via centralized helper - _sendTokenTo(tokens[inputTokenIndex], receiver, amountOutUint); + _sendTokenTo(tokens[inputTokenIndex], receiver, amountOutUint, unwrap); // Burn LP tokens from payer (authorization via allowance) if (msg.sender != payer) { diff --git a/src/PartyPoolSwapImpl.sol b/src/PartyPoolSwapImpl.sol index de5818c..73105c6 100644 --- a/src/PartyPoolSwapImpl.sol +++ b/src/PartyPoolSwapImpl.sol @@ -56,9 +56,10 @@ contract PartyPoolSwapImpl is PartyPoolBase { uint256 outputTokenIndex, int128 limitPrice, uint256 deadline, + bool unwrap, uint256 swapFeePpm, uint256 protocolFeePpm - ) external payable returns (uint256 amountInUsed, uint256 amountOut, uint256 fee) { + ) external payable native returns (uint256 amountInUsed, uint256 amountOut, uint256 fee) { uint256 n = tokens.length; require(inputTokenIndex < n && outputTokenIndex < n, "swapToLimit: idx"); require(limitPrice > int128(0), "swapToLimit: limit <= 0"); @@ -78,7 +79,7 @@ contract PartyPoolSwapImpl is PartyPoolBase { require(balIAfter == prevBalI + totalTransferAmount, "swapToLimit: non-standard tokenIn"); // Transfer output to receiver and verify exact decrease - _sendTokenTo(tokens[outputTokenIndex], receiver, amountOutUint); + _sendTokenTo(tokens[outputTokenIndex], receiver, amountOutUint, unwrap); uint256 balJAfter = IERC20(tokens[outputTokenIndex]).balanceOf(address(this)); require(balJAfter == prevBalJ - amountOutUint, "swapToLimit: non-standard tokenOut"); @@ -103,8 +104,6 @@ contract PartyPoolSwapImpl is PartyPoolBase { // Maintain original event semantics (logs input without fee) emit IPartyPool.Swap(payer, receiver, tokens[inputTokenIndex], tokens[outputTokenIndex], amountInUsedUint, amountOutUint); - _refund(); - return (amountInUsedUint, amountOutUint, feeUint); } diff --git a/test/GasTest.sol b/test/GasTest.sol index 20fa5fc..c43d589 100644 --- a/test/GasTest.sol +++ b/test/GasTest.sol @@ -244,10 +244,10 @@ contract GasTest is Test { vm.prank(alice); if (i % 2 == 0) { // swap token0 -> token1 - testPool.swap(alice, alice, 0, 1, maxIn, 0, 0); + testPool.swap(alice, alice, 0, 1, maxIn, 0, 0, false); } else { // swap token1 -> token0 - testPool.swap(alice, alice, 1, 0, maxIn, 0, 0); + testPool.swap(alice, alice, 1, 0, maxIn, 0, 0, false); } } } @@ -308,7 +308,7 @@ contract GasTest is Test { // If nothing minted (numerical edge), skip burn step if (minted == 0) continue; // Immediately burn the minted LP back to tokens, targeting the same token index - testPool.burnSwap(alice, alice, minted, 0, 0); + testPool.burnSwap(alice, alice, minted, 0, 0, false); } vm.stopPrank(); @@ -368,7 +368,7 @@ contract GasTest is Test { } // Burn via plain burn() which will transfer underlying back to alice and burn LP - testPool.burn(alice, alice, actualMinted, 0); + testPool.burn(alice, alice, actualMinted, 0, false); } vm.stopPrank(); diff --git a/test/PartyPool.t.sol b/test/PartyPool.t.sol index a97c9b7..03e5e6c 100644 --- a/test/PartyPool.t.sol +++ b/test/PartyPool.t.sol @@ -408,7 +408,7 @@ contract PartyPoolTest is Test { // Burn by sending LP tokens from this contract (which holds initial LP from setUp) // Call burn(payer=this, receiver=bob, lpAmount=totalLp) - pool.burn(address(this), bob, totalLp, 0); + pool.burn(address(this), bob, totalLp, 0, false); // After burning entire pool, totalSupply should be zero or very small (we expect zero since we withdrew all) assertEq(pool.totalSupply(), 0); @@ -434,7 +434,7 @@ contract PartyPoolTest is Test { // Execute swap: token0 -> token1 vm.prank(alice); - (uint256 amountInUsed, uint256 amountOut, uint256 fee) = pool.swap(alice, bob, 0, 1, maxIn, 0, 0); + (uint256 amountInUsed, uint256 amountOut, uint256 fee) = pool.swap(alice, bob, 0, 1, maxIn, 0, 0, false); // Amounts should be positive and not exceed provided max assertTrue(amountInUsed > 0, "expected some input used"); @@ -463,7 +463,7 @@ contract PartyPoolTest is Test { vm.prank(alice); vm.expectRevert(bytes("LMSR: limitPrice <= current price")); - pool.swap(alice, alice, 0, 1, 1000, limitPrice, 0); + pool.swap(alice, alice, 0, 1, 1000, limitPrice, 0, false); } /// @notice swapToLimit should compute input needed to reach a slightly higher price and execute. @@ -475,7 +475,7 @@ contract PartyPoolTest is Test { token0.approve(address(pool), type(uint256).max); vm.prank(alice); - (uint256 amountInUsed, uint256 amountOut, uint256 fee) = pool.swapToLimit(alice, bob, 0, 1, limitPrice, 0); + (uint256 amountInUsed, uint256 amountOut, uint256 fee) = pool.swapToLimit(alice, bob, 0, 1, limitPrice, 0, false); assertTrue(amountInUsed > 0, "expected some input used for swapToLimit"); assertTrue(amountOut > 0, "expected some output for swapToLimit"); @@ -637,7 +637,7 @@ contract PartyPoolTest is Test { uint256 b2Before = token2.balanceOf(bob); // Perform burn using the computed LP amount (proportional withdrawal) - pool.burn(address(this), bob, req, 0); + pool.burn(address(this), bob, req, 0, false); // Verify bob received exactly the expected amounts assertEq(token0.balanceOf(bob) - b0Before, expected[0], "token0 withdraw mismatch"); @@ -701,7 +701,7 @@ contract PartyPoolTest is Test { beforeBal[8] = token8.balanceOf(bob); beforeBal[9] = token9.balanceOf(bob); - pool10.burn(address(this), bob, req, 0); + pool10.burn(address(this), bob, req, 0, false); // Verify bob received each expected amount assertEq(token0.balanceOf(bob) - beforeBal[0], expected[0], "t0 withdraw mismatch"); @@ -795,7 +795,7 @@ contract PartyPoolTest is Test { uint256 bobBefore = token0.balanceOf(bob); // Call burnSwap where this contract is the payer (it holds initial LP from setUp) - uint256 payout = pool.burnSwap(address(this), bob, lpToBurn, target, 0); + uint256 payout = pool.burnSwap(address(this), bob, lpToBurn, target, 0, false); // Payout must be > 0 assertTrue(payout > 0, "burnSwap should produce a payout"); @@ -1040,8 +1040,8 @@ contract PartyPoolTest is Test { token0.approve(address(poolCustom), type(uint256).max); // Perform identical swaps: token0 -> token1 - (uint256 amountInDefault, uint256 amountOutDefault, uint256 feeDefault) = poolDefault.swap(alice, alice, 0, 1, swapAmount, 0, 0); - (uint256 amountInCustom, uint256 amountOutCustom, uint256 feeCustom) = poolCustom.swap(alice, alice, 0, 1, swapAmount, 0, 0); + (uint256 amountInDefault, uint256 amountOutDefault, uint256 feeDefault) = poolDefault.swap(alice, alice, 0, 1, swapAmount, 0, 0, false); + (uint256 amountInCustom, uint256 amountOutCustom, uint256 feeCustom) = poolCustom.swap(alice, alice, 0, 1, swapAmount, 0, 0, false); // Swap results should be identical assertEq(amountInDefault, amountInCustom, "Swap input amounts should be identical");