From fef6d007d8d784d39ee080359123e645ad98d563 Mon Sep 17 00:00:00 2001 From: tim Date: Mon, 15 Sep 2025 18:51:47 -0400 Subject: [PATCH] return fee taken --- src/IPartyPool.sol | 8 ++++---- src/PartyPool.sol | 46 ++++++++++++++++++++++++++------------------ test/PartyPool.t.sol | 8 ++++++-- 3 files changed, 37 insertions(+), 25 deletions(-) diff --git a/src/IPartyPool.sol b/src/IPartyPool.sol index 7966e89..df612eb 100644 --- a/src/IPartyPool.sol +++ b/src/IPartyPool.sol @@ -93,7 +93,7 @@ interface IPartyPool is IERC20Metadata { uint256 j, uint256 maxAmountIn, int128 limitPrice - ) external view returns (uint256 amountIn, uint256 amountOut); + ) external view returns (uint256 amountIn, uint256 amountOut, uint256 fee); function swap( address payer, @@ -103,14 +103,14 @@ interface IPartyPool is IERC20Metadata { uint256 maxAmountIn, int128 limitPrice, uint256 deadline - ) external returns (uint256 amountIn, uint256 amountOut); + ) external returns (uint256 amountIn, uint256 amountOut, uint256 fee); /// @notice External view to quote swap-to-limit amounts (gross input incl. fee and output), matching swapToLimit() computations function swapToLimitAmounts( uint256 i, uint256 j, int128 limitPrice - ) external view returns (uint256 amountIn, uint256 amountOut); + ) external view returns (uint256 amountIn, uint256 amountOut, uint256 fee); function swapToLimit( address payer, @@ -119,7 +119,7 @@ interface IPartyPool is IERC20Metadata { uint256 j, int128 limitPrice, uint256 deadline - ) external returns (uint256 amountInUsed, uint256 amountOut); + ) external returns (uint256 amountInUsed, uint256 amountOut, uint256 fee); /// @notice Single-token mint: deposit a single token, charge swap-LMSR cost, and mint LP. /// @param payer who transfers the input token diff --git a/src/PartyPool.sol b/src/PartyPool.sol index 12b68c9..f3761fe 100644 --- a/src/PartyPool.sol +++ b/src/PartyPool.sol @@ -336,7 +336,8 @@ contract PartyPool is IPartyPool, ERC20, ReentrancyGuard { /// @notice Internal quote for exact-input swap that mirrors swap() rounding and fee application /// @return grossIn amount to transfer in (inclusive of fee), amountOutUint output amount (uint), - /// amountInInternalUsed and amountOutInternal (64.64), amountInUintNoFee input amount excluding fee (uint) + /// amountInInternalUsed and amountOutInternal (64.64), amountInUintNoFee input amount excluding fee (uint), + /// feeUint fee taken from the gross input (uint) function _quoteSwapExactIn( uint256 i, uint256 j, @@ -350,7 +351,8 @@ contract PartyPool is IPartyPool, ERC20, ReentrancyGuard { uint256 amountOutUint, int128 amountInInternalUsed, int128 amountOutInternal, - uint256 amountInUintNoFee + uint256 amountInUintNoFee, + uint256 feeUint ) { uint256 n = tokens.length; @@ -373,9 +375,11 @@ contract PartyPool is IPartyPool, ERC20, ReentrancyGuard { require(amountInUintNoFee > 0, "swap: input zero"); // Compute gross transfer including fee on the used input (ceil) + feeUint = 0; grossIn = amountInUintNoFee; if (swapFeePpm > 0) { - grossIn += _ceilFee(amountInUintNoFee, swapFeePpm); + feeUint = _ceilFee(amountInUintNoFee, swapFeePpm); + grossIn += feeUint; } // Ensure within user max @@ -388,7 +392,8 @@ contract PartyPool is IPartyPool, ERC20, ReentrancyGuard { /// @notice Internal quote for swap-to-limit that mirrors swapToLimit() rounding and fee application /// @return grossIn amount to transfer in (inclusive of fee), amountOutUint output amount (uint), - /// amountInInternal and amountOutInternal (64.64), amountInUintNoFee input amount excluding fee (uint) + /// amountInInternal and amountOutInternal (64.64), amountInUintNoFee input amount excluding fee (uint), + /// feeUint fee taken from the gross input (uint) function _quoteSwapToLimit( uint256 i, uint256 j, @@ -401,7 +406,8 @@ contract PartyPool is IPartyPool, ERC20, ReentrancyGuard { uint256 amountOutUint, int128 amountInInternal, int128 amountOutInternal, - uint256 amountInUintNoFee + uint256 amountInUintNoFee, + uint256 feeUint ) { uint256 n = tokens.length; @@ -416,9 +422,11 @@ contract PartyPool is IPartyPool, ERC20, ReentrancyGuard { amountInUintNoFee = _internalToUintCeil(amountInInternal, bases[i]); require(amountInUintNoFee > 0, "swapToLimit: input zero"); + feeUint = 0; grossIn = amountInUintNoFee; if (swapFeePpm > 0) { - grossIn += _ceilFee(amountInUintNoFee, swapFeePpm); + feeUint = _ceilFee(amountInUintNoFee, swapFeePpm); + grossIn += feeUint; } amountOutUint = _internalToUintFloor(amountOutInternal, bases[j]); @@ -431,9 +439,9 @@ contract PartyPool is IPartyPool, ERC20, ReentrancyGuard { uint256 j, uint256 maxAmountIn, int128 limitPrice - ) external view returns (uint256 amountIn, uint256 amountOut) { - (uint256 grossIn, uint256 outUint,,,) = _quoteSwapExactIn(i, j, maxAmountIn, limitPrice); - return (grossIn, outUint); + ) external view returns (uint256 amountIn, uint256 amountOut, uint256 fee) { + (uint256 grossIn, uint256 outUint,,,, uint256 feeUint) = _quoteSwapExactIn(i, j, maxAmountIn, limitPrice); + return (grossIn, outUint, feeUint); } /// @notice External view to quote swap-to-limit amounts (gross input incl. fee and output), matching swapToLimit() computations @@ -441,9 +449,9 @@ contract PartyPool is IPartyPool, ERC20, ReentrancyGuard { uint256 i, uint256 j, int128 limitPrice - ) external view returns (uint256 amountIn, uint256 amountOut) { - (uint256 grossIn, uint256 outUint,,,) = _quoteSwapToLimit(i, j, limitPrice); - return (grossIn, outUint); + ) external view returns (uint256 amountIn, uint256 amountOut, uint256 fee) { + (uint256 grossIn, uint256 outUint,,,, uint256 feeUint) = _quoteSwapToLimit(i, j, limitPrice); + return (grossIn, outUint, feeUint); } @@ -455,7 +463,7 @@ contract PartyPool is IPartyPool, ERC20, ReentrancyGuard { /// @param maxAmountIn maximum amount of token i (uint256) to transfer in (inclusive of fees) /// @param limitPrice maximum acceptable marginal price (64.64 fixed point). Pass 0 to ignore. /// @param deadline timestamp after which the transaction will revert. Pass 0 to ignore. - /// @return amountIn actual input used (uint256), amountOut actual output sent (uint256) + /// @return amountIn actual input used (uint256), amountOut actual output sent (uint256), fee fee taken from the input (uint256) function swap( address payer, address receiver, @@ -464,7 +472,7 @@ contract PartyPool is IPartyPool, ERC20, ReentrancyGuard { uint256 maxAmountIn, int128 limitPrice, uint256 deadline - ) external nonReentrant returns (uint256 amountIn, uint256 amountOut) { + ) external nonReentrant returns (uint256 amountIn, uint256 amountOut, uint256 fee) { uint256 n = tokens.length; require(i < n && j < n, "swap: idx"); require(maxAmountIn > 0, "swap: input zero"); @@ -475,7 +483,7 @@ contract PartyPool is IPartyPool, ERC20, ReentrancyGuard { uint256 prevBalJ = IERC20(tokens[j]).balanceOf(address(this)); // Compute amounts using the same path as views - (uint256 totalTransferAmount, uint256 amountOutUint, int128 amountInInternalUsed, int128 amountOutInternal, ) = + (uint256 totalTransferAmount, uint256 amountOutUint, int128 amountInInternalUsed, int128 amountOutInternal, , uint256 feeUint) = _quoteSwapExactIn(i, j, maxAmountIn, limitPrice); // Transfer the exact amount from payer and require exact receipt (revert on fee-on-transfer) @@ -497,7 +505,7 @@ contract PartyPool is IPartyPool, ERC20, ReentrancyGuard { emit Swap(payer, receiver, tokens[i], tokens[j], totalTransferAmount, amountOutUint); - return (totalTransferAmount, amountOutUint); + return (totalTransferAmount, amountOutUint, feeUint); } /// @notice Swap up to the price limit; computes max input to reach limit then performs swap. @@ -511,7 +519,7 @@ contract PartyPool is IPartyPool, ERC20, ReentrancyGuard { uint256 j, int128 limitPrice, uint256 deadline - ) external returns (uint256 amountInUsed, uint256 amountOut) { + ) external returns (uint256 amountInUsed, uint256 amountOut, uint256 fee) { uint256 n = tokens.length; require(i < n && j < n, "swapToLimit: idx"); require(limitPrice > int128(0), "swapToLimit: limit <= 0"); @@ -522,7 +530,7 @@ contract PartyPool is IPartyPool, ERC20, ReentrancyGuard { uint256 prevBalJ = IERC20(tokens[j]).balanceOf(address(this)); // Compute amounts using the same path as views - (uint256 totalTransferAmount, uint256 amountOutUint, int128 amountInInternalMax, int128 amountOutInternal, uint256 amountInUsedUint) = + (uint256 totalTransferAmount, uint256 amountOutUint, int128 amountInInternalMax, int128 amountOutInternal, uint256 amountInUsedUint, uint256 feeUint) = _quoteSwapToLimit(i, j, limitPrice); // Transfer the exact amount needed from payer and require exact receipt (revert on fee-on-transfer) @@ -545,7 +553,7 @@ contract PartyPool is IPartyPool, ERC20, ReentrancyGuard { // Maintain original event semantics (logs input without fee) emit Swap(payer, receiver, tokens[i], tokens[j], amountInUsedUint, amountOutUint); - return (amountInUsedUint, amountOutUint); + return (amountInUsedUint, amountOutUint, feeUint); } /// @notice Ceiling fee helper: computes ceil(x * feePpm / 1_000_000) diff --git a/test/PartyPool.t.sol b/test/PartyPool.t.sol index 6f28d7e..0f48715 100644 --- a/test/PartyPool.t.sol +++ b/test/PartyPool.t.sol @@ -453,12 +453,14 @@ contract PartyPoolTest is Test { // Execute swap: token0 -> token1 vm.prank(alice); - (uint256 amountInUsed, uint256 amountOut) = pool.swap(alice, bob, 0, 1, maxIn, 0, 0); + (uint256 amountInUsed, uint256 amountOut, uint256 fee) = pool.swap(alice, bob, 0, 1, maxIn, 0, 0); // Amounts should be positive and not exceed provided max assertTrue(amountInUsed > 0, "expected some input used"); assertTrue(amountOut > 0, "expected some output returned"); assertTrue(amountInUsed <= maxIn, "used input must not exceed max"); + // Fee should be <= amountInUsed + assertTrue(fee <= amountInUsed, "fee must not exceed total input"); // Alice's balance decreased by exactly amountInUsed assertEq(token0.balanceOf(alice), balAliceBefore - amountInUsed); @@ -492,10 +494,12 @@ contract PartyPoolTest is Test { token0.approve(address(pool), type(uint256).max); vm.prank(alice); - (uint256 amountInUsed, uint256 amountOut) = pool.swapToLimit(alice, bob, 0, 1, limitPrice, 0); + (uint256 amountInUsed, uint256 amountOut, uint256 fee) = pool.swapToLimit(alice, bob, 0, 1, limitPrice, 0); assertTrue(amountInUsed > 0, "expected some input used for swapToLimit"); assertTrue(amountOut > 0, "expected some output for swapToLimit"); + // Fee should be <= amountInUsed (gross includes fee) + assertTrue(fee <= amountInUsed, "fee must not exceed total input for swapToLimit"); // Verify bob got the output assertEq(token1.balanceOf(bob) >= amountOut, true);