From 20af14c8727943a6475b775944aff22ca174bb6e Mon Sep 17 00:00:00 2001 From: tim Date: Mon, 6 Oct 2025 17:38:47 -0400 Subject: [PATCH] proxied swapToLimit --- script/DeployMock.sol | 2 +- src/IPartyPlanner.sol | 18 +-- src/LMSRStabilized.sol | 3 - src/PartyPlanner.sol | 27 ++-- src/PartyPool.sol | 280 ++++++++++++++------------------------ src/PartyPoolSwapImpl.sol | 123 +++++++++++++++++ test/PartyPlanner.t.sol | 44 +++--- 7 files changed, 273 insertions(+), 224 deletions(-) diff --git a/script/DeployMock.sol b/script/DeployMock.sol index 4a8d497..e54d6e7 100644 --- a/script/DeployMock.sol +++ b/script/DeployMock.sol @@ -58,7 +58,7 @@ contract DeployMock is Script { } // call full newPool signature on factory which will take the deposits and mint initial LP - (PartyPool pool, ) = planner.newPool( + (IPartyPool pool, ) = planner.newPool( name, symbol, tokens, diff --git a/src/IPartyPlanner.sol b/src/IPartyPlanner.sol index 64e777b..653488d 100644 --- a/src/IPartyPlanner.sol +++ b/src/IPartyPlanner.sol @@ -1,14 +1,16 @@ // SPDX-License-Identifier: UNLICENSED pragma solidity ^0.8.30; -import "./PartyPool.sol"; +import "./IPartyPool.sol"; +import "./PartyPoolMintImpl.sol"; +import "./PartyPoolSwapImpl.sol"; import "@openzeppelin/contracts/token/ERC20/IERC20.sol"; /// @title IPartyPlanner /// @notice Interface for factory contract for creating and tracking PartyPool instances interface IPartyPlanner { // Event emitted when a new pool is created - event PartyStarted(PartyPool indexed pool, string name, string symbol, IERC20[] tokens); + event PartyStarted(IPartyPool indexed pool, string name, string symbol, IERC20[] tokens); /// @notice Creates a new PartyPool instance and initializes it with initial deposits (legacy signature). /// @dev Deprecated in favour of the kappa-based overload below; kept for backwards compatibility. @@ -44,7 +46,7 @@ interface IPartyPlanner { uint256[] memory initialDeposits, uint256 initialLpAmount, uint256 deadline - ) external returns (PartyPool pool, uint256 lpAmount); + ) external returns (IPartyPool pool, uint256 lpAmount); /// @notice Creates a new PartyPool instance and initializes it with initial deposits (kappa-based). /// @param name_ LP token name @@ -77,7 +79,7 @@ interface IPartyPlanner { uint256[] memory initialDeposits, uint256 initialLpAmount, uint256 deadline - ) external returns (PartyPool pool, uint256 lpAmount); + ) external returns (IPartyPool pool, uint256 lpAmount); /// @notice Checks if a pool is supported /// @param pool The pool address to check @@ -92,7 +94,7 @@ interface IPartyPlanner { /// @param offset Starting index for pagination /// @param limit Maximum number of items to return /// @return pools Array of pool addresses for the requested page - function getAllPools(uint256 offset, uint256 limit) external view returns (PartyPool[] memory pools); + function getAllPools(uint256 offset, uint256 limit) external view returns (IPartyPool[] memory pools); /// @notice Returns the total number of unique tokens /// @return The total count of unique tokens @@ -114,12 +116,12 @@ interface IPartyPlanner { /// @param offset Starting index for pagination /// @param limit Maximum number of items to return /// @return pools Array of pool addresses containing the specified token - function getPoolsByToken(IERC20 token, uint256 offset, uint256 limit) external view returns (PartyPool[] memory pools); + function getPoolsByToken(IERC20 token, uint256 offset, uint256 limit) external view returns (IPartyPool[] memory pools); - /// @notice Address of the SwapMint implementation contract used by all pools created by this factory + /// @notice Address of the mint implementation contract used by all pools created by this factory function mintImpl() external view returns (PartyPoolMintImpl); - /// @notice Address of the SwapMint implementation contract used by all pools created by this factory + /// @notice Address of the swap implementation contract used by all pools created by this factory function swapMintImpl() external view returns (PartyPoolSwapImpl); } diff --git a/src/LMSRStabilized.sol b/src/LMSRStabilized.sol index f2227d3..f27e132 100644 --- a/src/LMSRStabilized.sol +++ b/src/LMSRStabilized.sol @@ -719,9 +719,6 @@ library LMSRStabilized { require(newTotal > int128(0), "LMSR: new total zero"); - // With kappa formulation, b automatically scales with pool size - int128 newB = s.kappa.mul(newTotal); - // Update the cached qInternal with new values for (uint i = 0; i < s.nAssets; ) { s.qInternal[i] = newQInternal[i]; diff --git a/src/PartyPlanner.sol b/src/PartyPlanner.sol index 9c399e7..eccc362 100644 --- a/src/PartyPlanner.sol +++ b/src/PartyPlanner.sol @@ -5,6 +5,7 @@ import {IERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/IERC20 import {SafeERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/utils/SafeERC20.sol"; import {IPartyPlanner} from "./IPartyPlanner.sol"; import {LMSRStabilized} from "./LMSRStabilized.sol"; +import {IPartyPool} from "./IPartyPool.sol"; import {PartyPool} from "./PartyPool.sol"; import {PartyPoolMintImpl} from "./PartyPoolMintImpl.sol"; import {PartyPoolSwapImpl} from "./PartyPoolSwapImpl.sol"; @@ -33,11 +34,11 @@ contract PartyPlanner is IPartyPlanner { function protocolFeeAddress() external view returns (address) { return PROTOCOL_FEE_ADDRESS; } // On-chain pool indexing - PartyPool[] private _allPools; + IPartyPool[] private _allPools; IERC20[] private _allTokens; - mapping(PartyPool => bool) private _poolSupported; + mapping(IPartyPool => bool) private _poolSupported; mapping(IERC20 => bool) private _tokenSupported; - mapping(IERC20 => PartyPool[]) private _poolsByToken; + mapping(IERC20 => IPartyPool[]) private _poolsByToken; /// @param _swapMintImpl address of the SwapMint implementation contract to be used by all pools /// @param _mintImpl address of the Mint implementation contract to be used by all pools @@ -76,7 +77,7 @@ contract PartyPlanner is IPartyPlanner { uint256[] memory initialDeposits, uint256 initialLpAmount, uint256 deadline - ) public returns (PartyPool pool, uint256 lpAmount) { + ) public returns (IPartyPool pool, uint256 lpAmount) { // Validate inputs require(deadline == 0 || block.timestamp <= deadline, "Planner: deadline exceeded"); require(_tokens.length == initialDeposits.length, "Planner: tokens and deposits length mismatch"); @@ -165,7 +166,7 @@ contract PartyPlanner is IPartyPlanner { uint256[] memory initialDeposits, uint256 initialLpAmount, uint256 deadline - ) external returns (PartyPool pool, uint256 lpAmount) { + ) external returns (IPartyPool pool, uint256 lpAmount) { // Validate fixed-point fractions: must be less than 1.0 in 64.64 fixed-point require(_tradeFrac < ONE, "Planner: tradeFrac must be < 1 (64.64)"); require(_targetSlippage < ONE, "Planner: targetSlippage must be < 1 (64.64)"); @@ -193,7 +194,7 @@ contract PartyPlanner is IPartyPlanner { /// @inheritdoc IPartyPlanner function getPoolSupported(address pool) external view returns (bool) { - return _poolSupported[PartyPool(pool)]; + return _poolSupported[IPartyPool(pool)]; } /// @inheritdoc IPartyPlanner @@ -202,19 +203,19 @@ contract PartyPlanner is IPartyPlanner { } /// @inheritdoc IPartyPlanner - function getAllPools(uint256 offset, uint256 limit) external view returns (PartyPool[] memory pools) { + function getAllPools(uint256 offset, uint256 limit) external view returns (IPartyPool[] memory pools) { uint256 totalPools = _allPools.length; // If offset is beyond array bounds, return empty array if (offset >= totalPools) { - return new PartyPool[](0); + return new IPartyPool[](0); } // Calculate actual number of pools to return (respecting bounds) uint256 itemsToReturn = (offset + limit > totalPools) ? (totalPools - offset) : limit; // Create result array of appropriate size - pools = new PartyPool[](itemsToReturn); + pools = new IPartyPool[](itemsToReturn); // Fill the result array for (uint256 i = 0; i < itemsToReturn; i++) { @@ -258,20 +259,20 @@ contract PartyPlanner is IPartyPlanner { } /// @inheritdoc IPartyPlanner - function getPoolsByToken(IERC20 token, uint256 offset, uint256 limit) external view returns (PartyPool[] memory pools) { - PartyPool[] storage tokenPools = _poolsByToken[token]; + function getPoolsByToken(IERC20 token, uint256 offset, uint256 limit) external view returns (IPartyPool[] memory pools) { + IPartyPool[] storage tokenPools = _poolsByToken[token]; uint256 totalPools = tokenPools.length; // If offset is beyond array bounds, return empty array if (offset >= totalPools) { - return new PartyPool[](0); + return new IPartyPool[](0); } // Calculate actual number of pools to return (respecting bounds) uint256 itemsToReturn = (offset + limit > totalPools) ? (totalPools - offset) : limit; // Create result array of appropriate size - pools = new PartyPool[](itemsToReturn); + pools = new IPartyPool[](itemsToReturn); // Fill the result array for (uint256 i = 0; i < itemsToReturn; i++) { diff --git a/src/PartyPool.sol b/src/PartyPool.sol index e46b8b8..7059ab4 100644 --- a/src/PartyPool.sol +++ b/src/PartyPool.sol @@ -134,6 +134,47 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { } + // + // Current marginal prices + // + + /// @notice Marginal price of `base` in terms of `quote` (p_quote / p_base) as Q64.64 + /// @dev Returns the LMSR marginal price directly (raw 64.64) for use by off-chain quoting logic. + function price(uint256 baseTokenIndex, uint256 quoteTokenIndex) external view returns (int128) { + uint256 n = tokens.length; + require(baseTokenIndex < n && quoteTokenIndex < n, "price: idx"); + require(lmsr.nAssets > 0, "price: uninit"); + return lmsr.price(baseTokenIndex, quoteTokenIndex); + } + + /// @notice Price of one LP token denominated in `quote` asset as Q64.64 + /// @dev Computes LMSR poolPrice (quote per unit qTotal) and scales it by totalSupply() / qTotal + /// to return price per LP token unit in quote asset (raw 64.64). + function poolPrice(uint256 quoteTokenIndex) external view returns (int128) { + uint256 n = tokens.length; + require(quoteTokenIndex < n, "poolPrice: idx"); + require(lmsr.nAssets > 0, "poolPrice: uninit"); + + // price per unit of qTotal (Q64.64) from LMSR + int128 pricePerQ = lmsr.poolPrice(quoteTokenIndex); + + // total internal q (qTotal) as Q64.64 + int128 qTotal = _computeSizeMetric(lmsr.qInternal); + require(qTotal > int128(0), "poolPrice: qTotal zero"); + + // totalSupply as Q64.64 + uint256 supply = totalSupply(); + require(supply > 0, "poolPrice: zero supply"); + int128 supplyQ64 = ABDKMath64x64.fromUInt(supply); + + // factor = totalSupply / qTotal (Q64.64) + int128 factor = supplyQ64.div(qTotal); + + // price per LP token = pricePerQ * factor (Q64.64) + return pricePerQ.mul(factor); + } + + /* ---------------------- Initialization / Mint / Burn (LP token managed) ---------------------- */ @@ -215,48 +256,6 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { } /// @inheritdoc IPartyPool - function swapToLimitAmounts( - uint256 inputTokenIndex, - uint256 outputTokenIndex, - int128 limitPrice - ) external view returns (uint256 amountIn, uint256 amountOut, uint256 fee) { - (uint256 grossIn, uint256 outUint,,,, uint256 feeUint) = _quoteSwapToLimit(inputTokenIndex, outputTokenIndex, limitPrice); - return (grossIn, outUint, feeUint); - } - - - /// @notice Transfer all protocol fees to the configured protocolFeeAddress and zero the ledger. - /// @dev Anyone can call; must have protocolFeeAddress != address(0) to be operational. - function collectProtocolFees() external nonReentrant { - address dest = PROTOCOL_FEE_ADDRESS; - require(dest != address(0), "collect: zero addr"); - - uint256 n = tokens.length; - for (uint256 i = 0; i < n; i++) { - uint256 owed = protocolFeesOwed[i]; - if (owed == 0) continue; - uint256 bal = IERC20(tokens[i]).balanceOf(address(this)); - require(bal >= owed, "collect: fee > bal"); - protocolFeesOwed[i] = 0; - // transfer owed tokens to protocol destination - tokens[i].safeTransfer(dest, owed); - // update cached to effective onchain minus owed - cachedUintBalances[i] = bal - owed; - } - } - - - /// @notice Swap input token i -> token j. Payer must approve token i. - /// @dev This function transfers the exact gross input (including fee) from payer and sends the computed output to receiver. - /// Non-standard tokens (fee-on-transfer, rebasers) are rejected via balance checks. - /// @param payer address of the account that pays for the swap - /// @param receiver address that will receive the output tokens - /// @param inputTokenIndex index of input asset - /// @param outputTokenIndex index of output asset - /// @param maxAmountIn maximum amount of token i (uint256) to transfer in (inclusive of fees) - /// @param limitPrice maximum acceptable marginal price (64.64 fixed point). Pass 0 to ignore. - /// @param deadline timestamp after which the transaction will revert. Pass 0 to ignore. - /// @return amountIn actual input used (uint256), amountOut actual output sent (uint256), fee fee taken from the input (uint256) function swap( address payer, address receiver, @@ -309,62 +308,6 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { return (totalTransferAmount, amountOutUint, feeUint); } - /// @notice Swap up to the price limit; computes max input to reach limit then performs swap. - /// @dev If balances prevent fully reaching the limit, the function caps and returns actuals. - /// The payer must transfer the exact gross input computed by the view. - /// @param deadline timestamp after which the transaction will revert. Pass 0 to ignore. - function swapToLimit( - address payer, - address receiver, - uint256 inputTokenIndex, - uint256 outputTokenIndex, - int128 limitPrice, - uint256 deadline - ) external returns (uint256 amountInUsed, uint256 amountOut, uint256 fee) { - uint256 n = tokens.length; - require(inputTokenIndex < n && outputTokenIndex < n, "swapToLimit: idx"); - require(limitPrice > int128(0), "swapToLimit: limit <= 0"); - require(deadline == 0 || block.timestamp <= deadline, "swapToLimit: deadline exceeded"); - - // Read previous balances for affected assets - uint256 prevBalI = IERC20(tokens[inputTokenIndex]).balanceOf(address(this)); - uint256 prevBalJ = IERC20(tokens[outputTokenIndex]).balanceOf(address(this)); - - // Compute amounts using the same path as views - (uint256 totalTransferAmount, uint256 amountOutUint, int128 amountInInternalMax, int128 amountOutInternal, uint256 amountInUsedUint, uint256 feeUint) = - _quoteSwapToLimit(inputTokenIndex, outputTokenIndex, limitPrice); - - // Transfer the exact amount needed from payer and require exact receipt (revert on fee-on-transfer) - tokens[inputTokenIndex].safeTransferFrom(payer, address(this), totalTransferAmount); - uint256 balIAfter = IERC20(tokens[inputTokenIndex]).balanceOf(address(this)); - require(balIAfter == prevBalI + totalTransferAmount, "swapToLimit: non-standard tokenIn"); - - // Transfer output to receiver and verify exact decrease - tokens[outputTokenIndex].safeTransfer(receiver, amountOutUint); - uint256 balJAfter = IERC20(tokens[outputTokenIndex]).balanceOf(address(this)); - require(balJAfter == prevBalJ - amountOutUint, "swapToLimit: non-standard tokenOut"); - - // Accrue protocol share (floor) from the fee on input token - if (PROTOCOL_FEE_PPM > 0 && feeUint > 0 && PROTOCOL_FEE_ADDRESS != address(0)) { - uint256 protoShare = (feeUint * PROTOCOL_FEE_PPM) / 1_000_000; // floor - if (protoShare > 0) { - protocolFeesOwed[inputTokenIndex] += protoShare; - } - } - - // Update caches to effective balances - _recordCachedBalance(inputTokenIndex, balIAfter); - _recordCachedBalance(outputTokenIndex, balJAfter); - - // Apply swap to LMSR state with the internal amounts - lmsr.applySwap(inputTokenIndex, outputTokenIndex, amountInInternalMax, amountOutInternal); - - // Maintain original event semantics (logs input without fee) - emit Swap(payer, receiver, tokens[inputTokenIndex], tokens[outputTokenIndex], amountInUsedUint, amountOutUint); - - return (amountInUsedUint, amountOutUint, feeUint); - } - /// @notice Internal quote for exact-input swap that mirrors swap() rounding and fee application /// @dev Returns amounts consistent with swap() semantics: grossIn includes fees (ceil), amountOut is floored. /// @return grossIn amount to transfer in (inclusive of fee), amountOutUint output amount (uint), @@ -423,65 +366,45 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { require(amountOutUint > 0, "swap: output zero"); } - /// @notice Internal quote for swap-to-limit that mirrors swapToLimit() rounding and fee application - /// @dev Computes the input required to reach limitPrice and the resulting output; all rounding matches swapToLimit. - /// @return grossIn amount to transfer in (inclusive of fee), amountOutUint output amount (uint), - /// amountInInternal and amountOutInternal (64.64), amountInUintNoFee input amount excluding fee (uint), - /// feeUint fee taken from the gross input (uint) - function _quoteSwapToLimit( + + /// @inheritdoc IPartyPool + function swapToLimitAmounts( uint256 inputTokenIndex, uint256 outputTokenIndex, int128 limitPrice - ) - internal - view - returns ( - uint256 grossIn, - uint256 amountOutUint, - int128 amountInInternal, - int128 amountOutInternal, - uint256 amountInUintNoFee, - uint256 feeUint - ) - { - uint256 n = tokens.length; - require(inputTokenIndex < n && outputTokenIndex < n, "swapToLimit: idx"); + ) external view returns (uint256 amountIn, uint256 amountOut, uint256 fee) { + require(inputTokenIndex < tokens.length && outputTokenIndex < tokens.length, "swapToLimit: idx"); require(limitPrice > int128(0), "swapToLimit: limit <= 0"); require(lmsr.nAssets > 0, "swapToLimit: pool uninitialized"); - // Compute internal maxima at the price limit - (amountInInternal, amountOutInternal) = lmsr.swapAmountsForPriceLimit(inputTokenIndex, outputTokenIndex, limitPrice); - - // Convert input to uint (ceil) and output to uint (floor) - amountInUintNoFee = _internalToUintCeil(amountInInternal, bases[inputTokenIndex]); - require(amountInUintNoFee > 0, "swapToLimit: input zero"); - - feeUint = 0; - grossIn = amountInUintNoFee; - if (SWAP_FEE_PPM > 0) { - feeUint = _ceilFee(amountInUintNoFee, SWAP_FEE_PPM); - grossIn += feeUint; - } - - amountOutUint = _internalToUintFloor(amountOutInternal, bases[outputTokenIndex]); - require(amountOutUint > 0, "swapToLimit: output zero"); + return SWAP_IMPL.swapToLimitAmounts( + inputTokenIndex, outputTokenIndex, limitPrice, + bases, KAPPA, lmsr.qInternal, SWAP_FEE_PPM); } - /// @notice Compute fee and net amounts for a gross input (fee rounded up to favor the pool). - /// @return feeUint fee taken (uint) and netUint remaining for protocol use (uint) - function _computeFee(uint256 gross) internal view returns (uint256 feeUint, uint256 netUint) { - if (SWAP_FEE_PPM == 0) { - return (0, gross); - } - feeUint = _ceilFee(gross, SWAP_FEE_PPM); - netUint = gross - feeUint; - } - /// @notice Convenience: return gross = net + fee(net) using ceiling for fee. - function _addFee(uint256 netUint) internal view returns (uint256 gross) { - if (SWAP_FEE_PPM == 0) return netUint; - uint256 fee = _ceilFee(netUint, SWAP_FEE_PPM); - return netUint + fee; + /// @inheritdoc IPartyPool + function swapToLimit( + address payer, + address receiver, + uint256 inputTokenIndex, + uint256 outputTokenIndex, + int128 limitPrice, + uint256 deadline + ) external nonReentrant returns (uint256 amountInUsed, uint256 amountOut, uint256 fee) { + bytes memory data = abi.encodeWithSignature( + 'swapToLimit(address,address,uint256,uint256,int128,uint256,uint256,uint256)', + payer, + receiver, + inputTokenIndex, + outputTokenIndex, + limitPrice, + deadline, + SWAP_FEE_PPM, + PROTOCOL_FEE_PPM + ); + bytes memory result = Address.functionDelegateCall(address(SWAP_IMPL), data); + return abi.decode(result, (uint256,uint256,uint256)); } function swapMintAmounts(uint256 inputTokenIndex, uint256 maxAmountIn) external view @@ -661,40 +584,24 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { } - /// @notice Marginal price of `base` in terms of `quote` (p_quote / p_base) as Q64.64 - /// @dev Returns the LMSR marginal price directly (raw 64.64) for use by off-chain quoting logic. - function price(uint256 baseTokenIndex, uint256 quoteTokenIndex) external view returns (int128) { + /// @notice Transfer all protocol fees to the configured protocolFeeAddress and zero the ledger. + /// @dev Anyone can call; must have protocolFeeAddress != address(0) to be operational. + function collectProtocolFees() external nonReentrant { + address dest = PROTOCOL_FEE_ADDRESS; + require(dest != address(0), "collect: zero addr"); + uint256 n = tokens.length; - require(baseTokenIndex < n && quoteTokenIndex < n, "price: idx"); - require(lmsr.nAssets > 0, "price: uninit"); - return lmsr.price(baseTokenIndex, quoteTokenIndex); - } - - /// @notice Price of one LP token denominated in `quote` asset as Q64.64 - /// @dev Computes LMSR poolPrice (quote per unit qTotal) and scales it by totalSupply() / qTotal - /// to return price per LP token unit in quote asset (raw 64.64). - function poolPrice(uint256 quoteTokenIndex) external view returns (int128) { - uint256 n = tokens.length; - require(quoteTokenIndex < n, "poolPrice: idx"); - require(lmsr.nAssets > 0, "poolPrice: uninit"); - - // price per unit of qTotal (Q64.64) from LMSR - int128 pricePerQ = lmsr.poolPrice(quoteTokenIndex); - - // total internal q (qTotal) as Q64.64 - int128 qTotal = _computeSizeMetric(lmsr.qInternal); - require(qTotal > int128(0), "poolPrice: qTotal zero"); - - // totalSupply as Q64.64 - uint256 supply = totalSupply(); - require(supply > 0, "poolPrice: zero supply"); - int128 supplyQ64 = ABDKMath64x64.fromUInt(supply); - - // factor = totalSupply / qTotal (Q64.64) - int128 factor = supplyQ64.div(qTotal); - - // price per LP token = pricePerQ * factor (Q64.64) - return pricePerQ.mul(factor); + for (uint256 i = 0; i < n; i++) { + uint256 owed = protocolFeesOwed[i]; + if (owed == 0) continue; + uint256 bal = IERC20(tokens[i]).balanceOf(address(this)); + require(bal >= owed, "collect: fee > bal"); + protocolFeesOwed[i] = 0; + // transfer owed tokens to protocol destination + tokens[i].safeTransfer(dest, owed); + // update cached to effective onchain minus owed + cachedUintBalances[i] = bal - owed; + } } @@ -703,4 +610,15 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { return lmsr.swapAmountsForExactInput(i, j, a, limitPrice); } + /// @notice Compute fee and net amounts for a gross input (fee rounded up to favor the pool). + /// @return feeUint fee taken (uint) and netUint remaining for protocol use (uint) + function _computeFee(uint256 gross) internal view returns (uint256 feeUint, uint256 netUint) { + return _computeFee(gross, SWAP_FEE_PPM); + } + + /// @notice Convenience: return gross = net + fee(net) using ceiling for fee. + function _addFee(uint256 netUint) internal view returns (uint256 gross) { + return _addFee(netUint, SWAP_FEE_PPM); + } + } diff --git a/src/PartyPoolSwapImpl.sol b/src/PartyPoolSwapImpl.sol index b6c5422..38c4126 100644 --- a/src/PartyPoolSwapImpl.sol +++ b/src/PartyPoolSwapImpl.sol @@ -6,6 +6,7 @@ import {IERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/IERC20 import {SafeERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/utils/SafeERC20.sol"; import {LMSRStabilized} from "./LMSRStabilized.sol"; import {PartyPoolBase} from "./PartyPoolBase.sol"; +import {IPartyPool} from "./IPartyPool.sol"; /// @title PartyPoolSwapMintImpl - Implementation contract for swapMint and burnSwap functions /// @notice This contract contains the swapMint and burnSwap implementation that will be called via delegatecall @@ -15,5 +16,127 @@ contract PartyPoolSwapImpl is PartyPoolBase { using LMSRStabilized for LMSRStabilized.State; using SafeERC20 for IERC20; + function swapToLimitAmounts( + uint256 inputTokenIndex, + uint256 outputTokenIndex, + int128 limitPrice, + uint256[] memory bases, + int128 kappa, + int128[] memory qInternal, + uint256 swapFeePpm + ) external pure returns (uint256 amountIn, uint256 amountOut, uint256 fee) { + // Compute internal maxima at the price limit + (int128 amountInInternal, int128 amountOutInternal) = LMSRStabilized.swapAmountsForPriceLimit( + bases.length, kappa, qInternal, + inputTokenIndex, outputTokenIndex, limitPrice); + + // Convert input to uint (ceil) and output to uint (floor) + uint256 amountInUintNoFee = _internalToUintCeil(amountInInternal, bases[inputTokenIndex]); + require(amountInUintNoFee > 0, "swapToLimit: input zero"); + + fee = 0; + amountIn = amountInUintNoFee; + if (swapFeePpm > 0) { + fee = _ceilFee(amountInUintNoFee, swapFeePpm); + amountIn += fee; + } + + amountOut = _internalToUintFloor(amountOutInternal, bases[outputTokenIndex]); + require(amountOut > 0, "swapToLimit: output zero"); + } + + + function swapToLimit( + address payer, + address receiver, + uint256 inputTokenIndex, + uint256 outputTokenIndex, + int128 limitPrice, + uint256 deadline, + uint256 swapFeePpm, + uint256 protocolFeePpm + ) external returns (uint256 amountInUsed, uint256 amountOut, uint256 fee) { + uint256 n = tokens.length; + require(inputTokenIndex < n && outputTokenIndex < n, "swapToLimit: idx"); + require(limitPrice > int128(0), "swapToLimit: limit <= 0"); + require(deadline == 0 || block.timestamp <= deadline, "swapToLimit: deadline exceeded"); + + // Read previous balances for affected assets + uint256 prevBalI = IERC20(tokens[inputTokenIndex]).balanceOf(address(this)); + uint256 prevBalJ = IERC20(tokens[outputTokenIndex]).balanceOf(address(this)); + + // Compute amounts using the same path as views + (uint256 totalTransferAmount, uint256 amountOutUint, int128 amountInInternalMax, int128 amountOutInternal, uint256 amountInUsedUint, uint256 feeUint) = + _quoteSwapToLimit(inputTokenIndex, outputTokenIndex, limitPrice, swapFeePpm); + + // Transfer the exact amount needed from payer and require exact receipt (revert on fee-on-transfer) + tokens[inputTokenIndex].safeTransferFrom(payer, address(this), totalTransferAmount); + uint256 balIAfter = IERC20(tokens[inputTokenIndex]).balanceOf(address(this)); + require(balIAfter == prevBalI + totalTransferAmount, "swapToLimit: non-standard tokenIn"); + + // Transfer output to receiver and verify exact decrease + tokens[outputTokenIndex].safeTransfer(receiver, amountOutUint); + uint256 balJAfter = IERC20(tokens[outputTokenIndex]).balanceOf(address(this)); + require(balJAfter == prevBalJ - amountOutUint, "swapToLimit: non-standard tokenOut"); + + // Accrue protocol share (floor) from the fee on input token + if (protocolFeePpm > 0 && feeUint > 0 ) { + uint256 protoShare = (feeUint * protocolFeePpm) / 1_000_000; // floor + if (protoShare > 0) { + protocolFeesOwed[inputTokenIndex] += protoShare; + } + } + + // Update caches to effective balances + _recordCachedBalance(inputTokenIndex, balIAfter); + _recordCachedBalance(outputTokenIndex, balJAfter); + + // Apply swap to LMSR state with the internal amounts + lmsr.applySwap(inputTokenIndex, outputTokenIndex, amountInInternalMax, amountOutInternal); + + // Maintain original event semantics (logs input without fee) + emit IPartyPool.Swap(payer, receiver, tokens[inputTokenIndex], tokens[outputTokenIndex], amountInUsedUint, amountOutUint); + + return (amountInUsedUint, amountOutUint, feeUint); + } + + + /// @notice Internal quote for swap-to-limit that mirrors swapToLimit() rounding and fee application + /// @dev Computes the input required to reach limitPrice and the resulting output; all rounding matches swapToLimit. + /// @return grossIn amount to transfer in (inclusive of fee), amountOutUint output amount (uint), + /// amountInInternal and amountOutInternal (64.64), amountInUintNoFee input amount excluding fee (uint), + /// feeUint fee taken from the gross input (uint) + function _quoteSwapToLimit( + uint256 inputTokenIndex, + uint256 outputTokenIndex, + int128 limitPrice, + uint256 swapFeePpm + ) internal view + returns ( + uint256 grossIn, + uint256 amountOutUint, + int128 amountInInternal, + int128 amountOutInternal, + uint256 amountInUintNoFee, + uint256 feeUint + ) + { + // Compute internal maxima at the price limit + (amountInInternal, amountOutInternal) = lmsr.swapAmountsForPriceLimit(inputTokenIndex, outputTokenIndex, limitPrice); + + // Convert input to uint (ceil) and output to uint (floor) + amountInUintNoFee = _internalToUintCeil(amountInInternal, bases[inputTokenIndex]); + require(amountInUintNoFee > 0, "swapToLimit: input zero"); + + feeUint = 0; + grossIn = amountInUintNoFee; + if (swapFeePpm > 0) { + feeUint = _ceilFee(amountInUintNoFee, swapFeePpm); + grossIn += feeUint; + } + + amountOutUint = _internalToUintFloor(amountOutInternal, bases[outputTokenIndex]); + require(amountOutUint > 0, "swapToLimit: output zero"); + } } diff --git a/test/PartyPlanner.t.sol b/test/PartyPlanner.t.sol index a1ed953..a3ffb96 100644 --- a/test/PartyPlanner.t.sol +++ b/test/PartyPlanner.t.sol @@ -1,12 +1,20 @@ // SPDX-License-Identifier: UNLICENSED pragma solidity ^0.8.30; -import "forge-std/Test.sol"; -import "../src/LMSRStabilized.sol"; -import "../src/PartyPlanner.sol"; -import "../src/PartyPool.sol"; +import {CommonBase} from "../lib/forge-std/src/Base.sol"; +import {StdAssertions} from "../lib/forge-std/src/StdAssertions.sol"; +import {StdChains} from "../lib/forge-std/src/StdChains.sol"; +import {StdCheats, StdCheatsSafe} from "../lib/forge-std/src/StdCheats.sol"; +import {StdUtils} from "../lib/forge-std/src/StdUtils.sol"; +import {Test} from "../lib/forge-std/src/Test.sol"; +import {ERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/ERC20.sol"; +import {IERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/IERC20.sol"; import {Deploy} from "../src/Deploy.sol"; -import "@openzeppelin/contracts/token/ERC20/ERC20.sol"; +import {IPartyPool} from "../src/IPartyPool.sol"; +import {LMSRStabilized} from "../src/LMSRStabilized.sol"; +import {PartyPlanner} from "../src/PartyPlanner.sol"; +import {PartyPool} from "../src/PartyPool.sol"; +import {MockERC20} from "./PartyPlanner.t.sol"; // Mock ERC20 token for testing contract MockERC20 is ERC20 { @@ -88,7 +96,7 @@ contract PartyPlannerTest is Test { // Compute kappa then create pool via kappa overload int128 computedKappa = LMSRStabilized.computeKappaFromSlippage(tokens.length, tradeFrac, targetSlippage); - (PartyPool pool, uint256 lpAmount) = planner.newPool( + (IPartyPool pool, uint256 lpAmount) = planner.newPool( name, symbol, tokens, @@ -117,7 +125,7 @@ contract PartyPlannerTest is Test { assertEq(planner.poolsByTokenCount(IERC20(address(tokenB))), initialTokenBCount + 1, "TokenB pool count should increase"); // Verify pools can be retrieved - PartyPool[] memory allPools = planner.getAllPools(0, 10); + IPartyPool[] memory allPools = planner.getAllPools(0, 10); bool poolFound = false; for (uint256 i = 0; i < allPools.length; i++) { if (allPools[i] == pool) { @@ -128,7 +136,7 @@ contract PartyPlannerTest is Test { assertTrue(poolFound, "Created pool should be in getAllPools result"); // Verify pool appears in token-specific queries - PartyPool[] memory tokenAPools = planner.getPoolsByToken(IERC20(address(tokenA)), 0, 10); + IPartyPool[] memory tokenAPools = planner.getPoolsByToken(IERC20(address(tokenA)), 0, 10); bool poolInTokenA = false; for (uint256 i = 0; i < tokenAPools.length; i++) { if (tokenAPools[i] == pool) { @@ -138,7 +146,7 @@ contract PartyPlannerTest is Test { } assertTrue(poolInTokenA, "Pool should be indexed under tokenA"); - PartyPool[] memory tokenBPools = planner.getPoolsByToken(IERC20(address(tokenB)), 0, 10); + IPartyPool[] memory tokenBPools = planner.getPoolsByToken(IERC20(address(tokenB)), 0, 10); bool poolInTokenB = false; for (uint256 i = 0; i < tokenBPools.length; i++) { if (tokenBPools[i] == pool) { @@ -167,7 +175,7 @@ contract PartyPlannerTest is Test { deposits1[1] = INITIAL_DEPOSIT_AMOUNT; int128 kappa1 = LMSRStabilized.computeKappaFromSlippage(tokens1.length, int128((1 << 64) - 1), int128(1 << 62)); - (PartyPool pool1,) = planner.newPool( + (IPartyPool pool1,) = planner.newPool( "Pool 1", "LP1", tokens1, bases1, kappa1, 3000, 5000, false, payer, receiver, deposits1, 1000e18, 0 @@ -187,7 +195,7 @@ contract PartyPlannerTest is Test { deposits2[1] = INITIAL_DEPOSIT_AMOUNT / 1e12; // Adjust for 6 decimals int128 kappa2 = LMSRStabilized.computeKappaFromSlippage(tokens2.length, int128((1 << 64) - 1), int128(1 << 62)); - (PartyPool pool2,) = planner.newPool( + (IPartyPool pool2,) = planner.newPool( "Pool 2", "LP2", tokens2, bases2, kappa2, 3000, 5000, false, payer, receiver, deposits2, 1000e18, 0 @@ -203,7 +211,7 @@ contract PartyPlannerTest is Test { assertEq(planner.poolsByTokenCount(IERC20(address(tokenC))), 1, "TokenC should be in 1 pool"); // Verify tokenB appears in both pools - PartyPool[] memory tokenBPools = planner.getPoolsByToken(IERC20(address(tokenB)), 0, 10); + IPartyPool[] memory tokenBPools = planner.getPoolsByToken(IERC20(address(tokenB)), 0, 10); assertEq(tokenBPools.length, 2, "TokenB should have 2 pools"); bool pool1Found = false; @@ -274,7 +282,7 @@ contract PartyPlannerTest is Test { function test_poolIndexing_Pagination() public { // Create multiple pools for pagination testing uint256 numPools = 5; - PartyPool[] memory createdPools = new PartyPool[](numPools); + IPartyPool[] memory createdPools = new IPartyPool[](numPools); for (uint256 i = 0; i < numPools; i++) { IERC20[] memory tokens = new IERC20[](2); @@ -290,7 +298,7 @@ contract PartyPlannerTest is Test { deposits[1] = INITIAL_DEPOSIT_AMOUNT; int128 kappaLoop = LMSRStabilized.computeKappaFromSlippage(tokens.length, int128((1 << 64) - 1), int128(1 << 62)); - (PartyPool pool,) = planner.newPool( + (IPartyPool pool,) = planner.newPool( string(abi.encodePacked("Pool ", vm.toString(i))), string(abi.encodePacked("LP", vm.toString(i))), tokens, bases, @@ -304,19 +312,19 @@ contract PartyPlannerTest is Test { assertEq(planner.poolCount(), numPools, "Should have created all pools"); // Test pagination - get first 3 pools - PartyPool[] memory page1 = planner.getAllPools(0, 3); + IPartyPool[] memory page1 = planner.getAllPools(0, 3); assertEq(page1.length, 3, "First page should have 3 pools"); // Test pagination - get next 2 pools - PartyPool[] memory page2 = planner.getAllPools(3, 3); + IPartyPool[] memory page2 = planner.getAllPools(3, 3); assertEq(page2.length, 2, "Second page should have 2 pools"); // Test pagination - offset beyond bounds - PartyPool[] memory emptyPage = planner.getAllPools(10, 3); + IPartyPool[] memory emptyPage = planner.getAllPools(10, 3); assertEq(emptyPage.length, 0, "Should return empty array for out of bounds offset"); // Verify all pools are accessible through pagination - PartyPool[] memory allPools = planner.getAllPools(0, 10); + IPartyPool[] memory allPools = planner.getAllPools(0, 10); assertEq(allPools.length, numPools, "Should return all pools"); for (uint256 i = 0; i < numPools; i++) {