diff --git a/.gitignore b/.gitignore index 51f60db..7ab0110 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,8 @@ -# Compiler files cache/ out/ +chain.json + docs/ log/ .env diff --git a/bin/mock b/bin/mock index da22baf..dfa1b5d 100755 --- a/bin/mock +++ b/bin/mock @@ -10,6 +10,11 @@ cleanup() { kill $ANVIL_PID 2>/dev/null } +err() { + cleanup + exit $1 +} + # Set up trap to handle script exit trap cleanup EXIT @@ -38,7 +43,7 @@ while ! check_string "Listening on" "log/anvil.txt"; do fi done -forge script --code-size-limit ${CODE_SIZE_LIMIT} --private-key ${PRIVATE_KEY} DeployMock --fork-url http://localhost:8545 --broadcast "$@" +forge script --code-size-limit ${CODE_SIZE_LIMIT} --private-key ${PRIVATE_KEY} DeployMock --fork-url http://localhost:8545 --broadcast "$@" || err 1 echo "Press Ctrl+C to exit..." while true; do diff --git a/foundry.toml b/foundry.toml index bad1645..9ce8e42 100644 --- a/foundry.toml +++ b/foundry.toml @@ -9,7 +9,8 @@ remappings = [ optimizer=true optimizer_runs=999999999 viaIR=true -gas_reports = ['PartyPool'] +gas_reports = ['PartyPool', 'PartyPlanner'] +fs_permissions = [{ access = "write", path = "chain.json"}] [lint] exclude_lints=['mixed-case-variable', 'unaliased-plain-import', ] diff --git a/script/DeployMock.sol b/script/DeployMock.sol index 3d47536..5caf1f9 100644 --- a/script/DeployMock.sol +++ b/script/DeployMock.sol @@ -7,6 +7,7 @@ import "@abdk/ABDKMath64x64.sol"; import "../test/MockERC20.sol"; import "../src/IPartyPool.sol"; import "../src/PartyPool.sol"; +import "../src/PartyPlanner.sol"; contract DeployMock is Script { @@ -24,10 +25,10 @@ contract DeployMock is Script { string memory name = 'Mock Pool'; string memory symbol = 'MP'; - address[] memory tokens = new address[](3); - tokens[0] = address(usxd); - tokens[1] = address(fusd); - tokens[2] = address(dive); + IERC20[] memory tokens = new IERC20[](3); + tokens[0] = IERC20(usxd); + tokens[1] = IERC20(fusd); + tokens[2] = IERC20(dive); uint256[] memory _bases = new uint256[](3); _bases[0] = 10**6; _bases[1] = 10**6; @@ -36,21 +37,68 @@ contract DeployMock is Script { int128 _targetSlippage = ABDKMath64x64.divu(1,10000); uint256 _feePpm = 100; - IPartyPool pool = new PartyPool(name, symbol, tokens, _bases, _tradeFrac, _targetSlippage, _feePpm, _feePpm, false); + // deploy a PartyPlanner factory and create the pool via factory + PartyPlanner planner = new PartyPlanner(); - // initial mint - mintAll(address(pool), 10_000); - pool.mint(devAccount7, devAccount7, 0, 0); + // prepare initial deposits (10_000 units of each token, scaled by bases) + uint256[] memory initialDeposits = new uint256[](3); + initialDeposits[0] = _bases[0] * 10_000; + initialDeposits[1] = _bases[1] * 10_000; + initialDeposits[2] = _bases[2] * 10_000; + uint256 initialLpAmount = 0; + uint256 deadline = 0; - // give tokens to dev7 + // mint tokens to the deployer so it can fund the initial deposits and approve the factory + mintAll(msg.sender, 10_000); + + // approve factory to move initial deposits + for (uint i = 0; i < tokens.length; i++) { + IERC20(tokens[i]).approve(address(planner), initialDeposits[i]); + } + + // call full createPool signature on factory which will take the deposits and mint initial LP + (PartyPool pool, uint256 lpAmount) = planner.createPool( + name, + symbol, + tokens, + _bases, + _tradeFrac, + _targetSlippage, + _feePpm, + _feePpm, + false, + msg.sender, // payer: this script + devAccount7, // receiver of initial LP + initialDeposits, + initialLpAmount, + deadline + ); + + // give tokens to dev7 for later use mintAll(devAccount7, 1_000_000); vm.stopBroadcast(); - console2.log('\nPartyPool', address(pool)); - console2.log(' USXD', address(usxd)); - console2.log(' FUSD', address(fusd)); - console2.log(' DIVE', address(dive)); + // Set ENV vars + string memory plannerStr = vm.toString(address(planner)); + vm.setEnv('PLANNER', plannerStr); + vm.setEnv('POOL', vm.toString(address(pool))); + vm.setEnv('USXD', vm.toString(address(usxd))); + vm.setEnv('FUSD', vm.toString(address(fusd))); + vm.setEnv('DIVE', vm.toString(address(dive))); + + // Write JSON config file + string memory config = 'config'; + string memory chainConfig = 'chain config'; + string memory chainConfigStr = vm.serializeString(chainConfig, 'PartyPlannerV1', plannerStr); + string memory configStr = vm.serializeString(config, vm.toString(block.chainid), chainConfigStr); + vm.writeJson(configStr, 'chain.json'); + + console2.log('\nPartyPlanner', address(planner)); + console2.log(' PartyPool', address(pool)); + console2.log(' USXD', address(usxd)); + console2.log(' FUSD', address(fusd)); + console2.log(' DIVE', address(dive)); } MockERC20 private usxd; diff --git a/src/IPartyPlanner.sol b/src/IPartyPlanner.sol new file mode 100644 index 0000000..68716d0 --- /dev/null +++ b/src/IPartyPlanner.sol @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.30; + +import "./PartyPool.sol"; +import "@openzeppelin/contracts/token/ERC20/IERC20.sol"; + +/// @title IPartyPlanner +/// @notice Interface for factory contract for creating and tracking PartyPool instances +interface IPartyPlanner { + // Event emitted when a new pool is created + event PartyStarted(PartyPool indexed pool, string name, string symbol, IERC20[] tokens); + + /// @notice Creates a new PartyPool instance and initializes it with initial deposits + /// @param name_ LP token name + /// @param symbol_ LP token symbol + /// @param _tokens token addresses (n) + /// @param _bases scaling bases for each token (n) - used when converting to/from internal 64.64 amounts + /// @param _tradeFrac trade fraction in 64.64 fixed-point (as used by LMSR) + /// @param _targetSlippage target slippage in 64.64 fixed-point (as used by LMSR) + /// @param _swapFeePpm fee in parts-per-million, taken from swap input amounts before LMSR calculations + /// @param _flashFeePpm fee in parts-per-million, taken for flash loans + /// @param _stable if true and assets.length==2, then the optimization for 2-asset stablecoin pools is activated + /// @param payer address that provides the initial token deposits + /// @param receiver address that receives the minted LP tokens + /// @param initialDeposits amounts of each token to deposit initially + /// @param deadline Reverts if nonzero and the current blocktime is later than the deadline + /// @return pool Address of the newly created and initialized PartyPool + /// @return lpAmount Amount of LP tokens minted to the receiver + function createPool( + // Pool constructor args + string memory name_, + string memory symbol_, + IERC20[] memory _tokens, + uint256[] memory _bases, + int128 _tradeFrac, + int128 _targetSlippage, + uint256 _swapFeePpm, + uint256 _flashFeePpm, + bool _stable, + // Initial deposit information + address payer, + address receiver, + uint256[] memory initialDeposits, + uint256 initialLpAmount, + uint256 deadline + ) external returns (PartyPool pool, uint256 lpAmount); + + /// @notice Checks if a pool is supported + /// @param pool The pool address to check + /// @return bool True if the pool is supported, false otherwise + function getPoolSupported(address pool) external view returns (bool); + + /// @notice Returns the total number of pools created + /// @return The total count of pools + function poolCount() external view returns (uint256); + + /// @notice Retrieves a page of pool addresses + /// @param offset Starting index for pagination + /// @param limit Maximum number of items to return + /// @return pools Array of pool addresses for the requested page + function getAllPools(uint256 offset, uint256 limit) external view returns (PartyPool[] memory pools); + + /// @notice Returns the total number of unique tokens + /// @return The total count of unique tokens + function tokenCount() external view returns (uint256); + + /// @notice Retrieves a page of token addresses + /// @param offset Starting index for pagination + /// @param limit Maximum number of items to return + /// @return tokens Array of token addresses for the requested page + function getAllTokens(uint256 offset, uint256 limit) external view returns (address[] memory tokens); + + /// @notice Returns the total number of pools for a specific token + /// @param token The token address to query + /// @return The total count of pools containing the token + function poolsByTokenCount(IERC20 token) external view returns (uint256); + + /// @notice Retrieves a page of pool addresses for a specific token + /// @param token The token address to query pools for + /// @param offset Starting index for pagination + /// @param limit Maximum number of items to return + /// @return pools Array of pool addresses containing the specified token + function getPoolsByToken(IERC20 token, uint256 offset, uint256 limit) external view returns (PartyPool[] memory pools); +} diff --git a/src/IPartyPool.sol b/src/IPartyPool.sol index f3e4199..7d27eda 100644 --- a/src/IPartyPool.sol +++ b/src/IPartyPool.sol @@ -28,8 +28,8 @@ interface IPartyPool is IERC20Metadata { event Swap( address payer, address indexed receiver, - address indexed tokenIn, - address indexed tokenOut, + IERC20 indexed tokenIn, + IERC20 indexed tokenOut, uint256 amountIn, uint256 amountOut ); @@ -58,13 +58,13 @@ interface IPartyPool is IERC20Metadata { // Immutable pool configuration (public getters) /// @notice Token addresses comprising the pool. Effectively immutable after construction. /// @dev tokens[i] corresponds to the i-th asset and maps to index i in the internal LMSR arrays. - function tokens(uint256) external view returns (address); // get single token + function tokens(uint256) external view returns (IERC20); // get single token /// @notice Returns the number of tokens (n) in the pool. function numTokens() external view returns (uint256); /// @notice Returns the list of all token addresses in the pool (copy). - function allTokens() external view returns (address[] memory); + function allTokens() external view returns (IERC20[] memory); /// @notice Per-token uint base denominators used to convert uint token amounts <-> internal Q64.64 representation. /// @dev denominators()[i] is the base for tokens[i]. These bases are chosen by deployer and must match token decimals. @@ -85,7 +85,7 @@ interface IPartyPool is IERC20Metadata { /// @notice Mapping from token address => (index+1). A zero value indicates the token is not in the pool. /// @dev Use index = tokenAddressToIndexPlusOne[token] - 1 when non-zero. - function tokenAddressToIndexPlusOne(address) external view returns (uint); + function tokenAddressToIndexPlusOne(IERC20) external view returns (uint); // Initialization / Mint / Burn (LP token managed) @@ -105,7 +105,8 @@ interface IPartyPool is IERC20Metadata { /// @param receiver address that receives the LP tokens /// @param lpTokenAmount desired amount of LP tokens to mint (ignored for initial deposit) /// @param deadline timestamp after which the transaction will revert. Pass 0 to ignore. - function mint(address payer, address receiver, uint256 lpTokenAmount, uint256 deadline) external; + /// @return lpMinted the actual amount of lpToken minted + function mint(address payer, address receiver, uint256 lpTokenAmount, uint256 deadline) external returns (uint256 lpMinted); /// @notice Calculate the proportional withdrawal amounts for a given LP token amount /// @dev Returns the maximum token amounts (rounded down) that will be withdrawn when burning lpTokenAmount. diff --git a/src/LMSRStabilizedBalancedPair.sol b/src/LMSRStabilizedBalancedPair.sol index 6b91b4b..e89fa74 100644 --- a/src/LMSRStabilizedBalancedPair.sol +++ b/src/LMSRStabilizedBalancedPair.sol @@ -183,9 +183,6 @@ library LMSRStabilizedBalancedPair { // Now compute a two-tier approximation using Horner-style evaluation to reduce mul/divs. // Primary tier (cheap quadratic): accurate for small u = a/b. // Secondary tier (cubic correction): used when u is moderate but still within U_MAX. - int128 one = ONE; - int128 HALF = ABDKMath64x64.divu(1, 2); // 0.5 - int128 THIRD = ABDKMath64x64.divu(1, 3); // ~0.333... // Precomputed thresholds int128 U_TIER1 = ABDKMath64x64.divu(1, 10); // 0.1 -> cheap quadratic tier @@ -194,7 +191,7 @@ library LMSRStabilizedBalancedPair { // u is already computed above // Compute X = u*(1 + delta) - u^2/2 int128 u2 = u.mul(u); - int128 X = u.mul(one.add(delta)).sub(u2.div(ABDKMath64x64.fromUInt(2))); + int128 X = u.mul(ONE.add(delta)).sub(u2.div(ABDKMath64x64.fromUInt(2))); // Compute X^2 once int128 X2 = X.mul(X); diff --git a/src/PartyPlanner.sol b/src/PartyPlanner.sol new file mode 100644 index 0000000..a37b0ef --- /dev/null +++ b/src/PartyPlanner.sol @@ -0,0 +1,183 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.30; + +import "./IPartyPlanner.sol"; +import "./PartyPool.sol"; +import "@openzeppelin/contracts/token/ERC20/IERC20.sol"; +import "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol"; + +/// @title PartyPlanner +/// @notice Factory contract for creating and tracking PartyPool instances +contract PartyPlanner is IPartyPlanner { + using SafeERC20 for IERC20; + int128 private constant FIXED_ONE_64x64 = int128(1) << 64; + + // On-chain pool indexing + PartyPool[] private _allPools; + IERC20[] private _allTokens; + mapping(PartyPool => bool) private _poolSupported; + mapping(IERC20 => bool) private _tokenSupported; + mapping(IERC20 => PartyPool[]) private _poolsByToken; + + /// @inheritdoc IPartyPlanner + function createPool( + // Pool constructor args + string memory name_, + string memory symbol_, + IERC20[] memory _tokens, + uint256[] memory _bases, + int128 _tradeFrac, + int128 _targetSlippage, + uint256 _swapFeePpm, + uint256 _flashFeePpm, + bool _stable, + // Initial deposit information + address payer, + address receiver, + uint256[] memory initialDeposits, + uint256 initialLpAmount, + uint256 deadline + ) external returns (PartyPool pool, uint256 lpAmount) { + // Validate inputs + require(deadline == 0 || block.timestamp <= deadline, "Planner: deadline exceeded"); + require(_tokens.length == initialDeposits.length, "Planner: tokens and deposits length mismatch"); + require(payer != address(0), "Planner: payer cannot be zero address"); + require(receiver != address(0), "Planner: receiver cannot be zero address"); + + // Validate fixed-point fractions: must be less than 1.0 in 64.64 fixed-point + require(_tradeFrac < FIXED_ONE_64x64, "Planner: tradeFrac must be < 1 (64.64)"); + require(_targetSlippage < FIXED_ONE_64x64, "Planner: targetSlippage must be < 1 (64.64)"); + + // Create a new PartyPool instance + pool = new PartyPool( + name_, + symbol_, + _tokens, + _bases, + _tradeFrac, + _targetSlippage, + _swapFeePpm, + _flashFeePpm, + _stable + ); + + _allPools.push(pool); + _poolSupported[pool] = true; + + // Track tokens and populate mappings + for (uint256 i = 0; i < _tokens.length; i++) { + IERC20 token = _tokens[i]; + + // Add token to _allTokens if not already present + if (!_tokenSupported[token]) { + _allTokens.push(token); + _tokenSupported[token] = true; + } + + // Add pool to _poolsByToken mapping + _poolsByToken[token].push(pool); + } + + emit PartyStarted(pool, name_, symbol_, _tokens); + + // Transfer initial tokens from payer to the pool + for (uint256 i = 0; i < _tokens.length; i++) { + if (initialDeposits[i] > 0) { + IERC20(_tokens[i]).safeTransferFrom(payer, address(pool), initialDeposits[i]); + } + } + + // Call mint on the new pool to initialize it with the transferred tokens + lpAmount = pool.initialMint(receiver, initialLpAmount); + } + + /// @inheritdoc IPartyPlanner + function getPoolSupported(address pool) external view returns (bool) { + return _poolSupported[PartyPool(pool)]; + } + + /// @inheritdoc IPartyPlanner + function poolCount() external view returns (uint256) { + return _allPools.length; + } + + /// @inheritdoc IPartyPlanner + function getAllPools(uint256 offset, uint256 limit) external view returns (PartyPool[] memory pools) { + uint256 totalPools = _allPools.length; + + // If offset is beyond array bounds, return empty array + if (offset >= totalPools) { + return new PartyPool[](0); + } + + // Calculate actual number of pools to return (respecting bounds) + uint256 itemsToReturn = (offset + limit > totalPools) ? (totalPools - offset) : limit; + + // Create result array of appropriate size + pools = new PartyPool[](itemsToReturn); + + // Fill the result array + for (uint256 i = 0; i < itemsToReturn; i++) { + pools[i] = _allPools[offset + i]; + } + + return pools; + } + + /// @inheritdoc IPartyPlanner + function tokenCount() external view returns (uint256) { + return _allTokens.length; + } + + /// @inheritdoc IPartyPlanner + function getAllTokens(uint256 offset, uint256 limit) external view returns (address[] memory tokens) { + uint256 totalTokens = _allTokens.length; + + // If offset is beyond array bounds, return empty array + if (offset >= totalTokens) { + return new address[](0); + } + + // Calculate actual number of tokens to return (respecting bounds) + uint256 itemsToReturn = (offset + limit > totalTokens) ? (totalTokens - offset) : limit; + + // Create result array of appropriate size + tokens = new address[](itemsToReturn); + + // Fill the result array + for (uint256 i = 0; i < itemsToReturn; i++) { + tokens[i] = address(_allTokens[offset + i]); + } + + return tokens; + } + + /// @inheritdoc IPartyPlanner + function poolsByTokenCount(IERC20 token) external view returns (uint256) { + return _poolsByToken[token].length; + } + + /// @inheritdoc IPartyPlanner + function getPoolsByToken(IERC20 token, uint256 offset, uint256 limit) external view returns (PartyPool[] memory pools) { + PartyPool[] storage tokenPools = _poolsByToken[token]; + uint256 totalPools = tokenPools.length; + + // If offset is beyond array bounds, return empty array + if (offset >= totalPools) { + return new PartyPool[](0); + } + + // Calculate actual number of pools to return (respecting bounds) + uint256 itemsToReturn = (offset + limit > totalPools) ? (totalPools - offset) : limit; + + // Create result array of appropriate size + pools = new PartyPool[](itemsToReturn); + + // Fill the result array + for (uint256 i = 0; i < itemsToReturn; i++) { + pools[i] = tokenPools[offset + i]; + } + + return pools; + } +} diff --git a/src/PartyPool.sol b/src/PartyPool.sol index 8726e77..ee903fe 100644 --- a/src/PartyPool.sol +++ b/src/PartyPool.sol @@ -37,13 +37,13 @@ contract PartyPool is IPartyPool, ERC20, ReentrancyGuard { /// @notice Token addresses comprising the pool. Effectively immutable after construction. /// @dev tokens[i] corresponds to the i-th asset and maps to index i in the internal LMSR arrays. - address[] public tokens; // effectively immutable since there is no interface to change the tokens + IERC20[] public tokens; // effectively immutable since there is no interface to change the tokens /// @inheritdoc IPartyPool function numTokens() external view returns (uint256) { return tokens.length; } /// @inheritdoc IPartyPool - function allTokens() external view returns (address[] memory) { return tokens; } + function allTokens() external view returns (IERC20[] memory) { return tokens; } // NOTE that the slippage target is only exactly achieved in completely balanced pools where all assets are // priced the same. This target is actually a minimum slippage that the pool imposes on traders, and the actual @@ -84,7 +84,7 @@ contract PartyPool is IPartyPool, ERC20, ReentrancyGuard { /// @notice Mapping from token address => (index+1). A zero value indicates the token is not in the pool. /// @dev Use index = tokenAddressToIndexPlusOne[token] - 1 when non-zero. - mapping(address=>uint) public tokenAddressToIndexPlusOne; // Uses index+1 so a result of 0 indicates a failed lookup + mapping(IERC20=>uint) public tokenAddressToIndexPlusOne; // Uses index+1 so a result of 0 indicates a failed lookup /// @notice Scale factor used when converting LMSR Q64.64 totals to LP token units (uint). /// @dev LP tokens are minted in units equal to ABDK.mulu(lastTotalQ64x64, LP_SCALE). @@ -102,7 +102,7 @@ contract PartyPool is IPartyPool, ERC20, ReentrancyGuard { constructor( string memory name_, string memory symbol_, - address[] memory _tokens, + IERC20[] memory _tokens, uint256[] memory _bases, int128 _tradeFrac, int128 _targetSlippage, @@ -168,6 +168,120 @@ contract PartyPool is IPartyPool, ERC20, ReentrancyGuard { return depositAmounts; } + /// @notice Initial mint to set up pool for the first time. + /// @dev Assumes tokens have already been transferred to the pool prior to calling. + /// Can only be called when the pool is uninitialized (totalSupply() == 0 or lmsr.nAssets == 0). + /// @param receiver address that receives the LP tokens + /// @param lpTokens The number of LP tokens to issue for this mint. If 0, then the number of tokens returned will equal the LMSR internal q total + function initialMint(address receiver, uint256 lpTokens) external nonReentrant + returns (uint256 lpMinted) { + uint256 n = tokens.length; + + // Check if this is initial deposit - revert if not + bool isInitialDeposit = totalSupply() == 0 || lmsr.nAssets == 0; + require(isInitialDeposit, "initialMint: pool already initialized"); + + // Update cached balances for all assets + int128[] memory newQInternal = new int128[](n); + uint256[] memory depositAmounts = new uint256[](n); + for (uint i = 0; i < n; ) { + uint256 bal = IERC20(tokens[i]).balanceOf(address(this)); + cachedUintBalances[i] = bal; + newQInternal[i] = _uintToInternalFloor(bal, bases[i]); + depositAmounts[i] = bal; + unchecked { i++; } + } + + // Initialize the stabilized LMSR state + lmsr.init(newQInternal, tradeFrac, targetSlippage); + + // Compute actual LP tokens to mint based on size metric (scaled) + if( lpTokens != 0 ) + lpMinted = lpTokens; + else { + int128 newTotal = _computeSizeMetric(newQInternal); + lpMinted = ABDKMath64x64.mulu(newTotal, LP_SCALE); + } + + require(lpMinted > 0, "initialMint: zero LP amount"); + _mint(receiver, lpMinted); + emit Mint(address(0), receiver, depositAmounts, lpMinted); + } + + /// @notice Proportional mint for existing pool. + /// @dev Payer must approve the required token amounts before calling. + /// Can only be called when pool is already initialized (totalSupply() > 0 and lmsr.nAssets > 0). + /// Rounds follow the pool-favorable conventions documented in helpers (ceil inputs, floor outputs). + /// @param payer address that provides the input tokens + /// @param receiver address that receives the LP tokens + /// @param lpTokenAmount desired amount of LP tokens to mint + /// @param deadline timestamp after which the transaction will revert. Pass 0 to ignore. + function mint(address payer, address receiver, uint256 lpTokenAmount, uint256 deadline) external nonReentrant + returns (uint256 lpMinted) { + require(deadline == 0 || block.timestamp <= deadline, "mint: deadline exceeded"); + uint256 n = tokens.length; + + // Check if this is NOT initial deposit - revert if it is + bool isInitialDeposit = totalSupply() == 0 || lmsr.nAssets == 0; + require(!isInitialDeposit, "mint: use initialMint for pool initialization"); + require(lpTokenAmount > 0, "mint: zero LP amount"); + + // Capture old pool size metric (scaled) by computing from current balances + int128 oldTotal = _computeSizeMetric(lmsr.qInternal); + uint256 oldScaled = ABDKMath64x64.mulu(oldTotal, LP_SCALE); + + // Calculate required deposit amounts for the desired LP tokens + uint256[] memory depositAmounts = mintDepositAmounts(lpTokenAmount); + + // Transfer in all token amounts + for (uint i = 0; i < n; ) { + if (depositAmounts[i] > 0) { + tokens[i].safeTransferFrom(payer, address(this), depositAmounts[i]); + } + unchecked { i++; } + } + + // Update cached balances for all assets + int128[] memory newQInternal = new int128[](n); + for (uint i = 0; i < n; ) { + uint256 bal = IERC20(tokens[i]).balanceOf(address(this)); + cachedUintBalances[i] = bal; + newQInternal[i] = _uintToInternalFloor(bal, bases[i]); + unchecked { i++; } + } + + // Update for proportional change + lmsr.updateForProportionalChange(newQInternal); + + // Compute actual LP tokens to mint based on change in size metric (scaled) + // floor truncation rounds in favor of the pool + int128 newTotal = _computeSizeMetric(newQInternal); + uint256 newScaled = ABDKMath64x64.mulu(newTotal, LP_SCALE); + uint256 actualLpToMint; + + require(oldScaled > 0, "mint: oldScaled zero"); + uint256 delta = (newScaled > oldScaled) ? (newScaled - oldScaled) : 0; + // Proportional issuance: totalSupply * delta / oldScaled + if (delta > 0) { + // floor truncation rounds in favor of the pool + actualLpToMint = (totalSupply() * delta) / oldScaled; + } else { + actualLpToMint = 0; + } + + // Ensure the calculated LP amount is not too different from requested + require(actualLpToMint > 0, "mint: zero LP minted"); + + // Allow actual amount to be at most 0.00001% less than requested + // This accounts for rounding in deposit calculations + uint256 minAcceptable = lpTokenAmount * 99_999 / 100_000; + require(actualLpToMint >= minAcceptable, "mint: insufficient LP minted"); + + _mint(receiver, actualLpToMint); + emit Mint(payer, receiver, depositAmounts, actualLpToMint); + return actualLpToMint; + } + /// @inheritdoc IPartyPool function burnReceiveAmounts(uint256 lpTokenAmount) external view returns (uint256[] memory withdrawAmounts) { return _burnReceiveAmounts(lpTokenAmount); @@ -194,106 +308,6 @@ contract PartyPool is IPartyPool, ERC20, ReentrancyGuard { return withdrawAmounts; } - /// @notice Proportional mint (or initial supply if first call). - /// @dev - For initial supply: assumes tokens have already been transferred to the pool prior to calling. - /// - For subsequent mints: payer must approve the required token amounts before calling. - /// Rounds follow the pool-favorable conventions documented in helpers (ceil inputs, floor outputs). - /// @param payer address that provides the input tokens (ignored for initial deposit) - /// @param receiver address that receives the LP tokens - /// @param lpTokenAmount desired amount of LP tokens to mint (ignored for initial deposit) - /// @param deadline timestamp after which the transaction will revert. Pass 0 to ignore. - function mint(address payer, address receiver, uint256 lpTokenAmount, uint256 deadline) external nonReentrant { - require(deadline == 0 || block.timestamp <= deadline, "mint: deadline exceeded"); - uint256 n = tokens.length; - // Check if this is initial deposit - bool isInitialDeposit = totalSupply() == 0 || lmsr.nAssets == 0; - - require(lpTokenAmount > 0 || isInitialDeposit, "mint: zero LP amount"); - - // Capture old pool size metric (scaled) by computing from current balances - uint256 oldScaled = 0; - if (!isInitialDeposit) { - int128 oldTotal = _computeSizeMetric(lmsr.qInternal); - oldScaled = ABDKMath64x64.mulu(oldTotal, LP_SCALE); - } - - // For non-initial deposits, transfer tokens from payer - uint256[] memory depositAmounts = new uint256[](n); - - if (!isInitialDeposit) { - // Calculate required deposit amounts for the desired LP tokens - depositAmounts = mintDepositAmounts(lpTokenAmount); - - // Transfer in all token amounts - for (uint i = 0; i < n; ) { - if (depositAmounts[i] > 0) { - _safeTransferFrom(tokens[i], payer, address(this), depositAmounts[i]); - } - unchecked { i++; } - } - } - - // Update cached balances for all assets - int128[] memory newQInternal = new int128[](n); - for (uint i = 0; i < n; ) { - uint256 bal = IERC20(tokens[i]).balanceOf(address(this)); - cachedUintBalances[i] = bal; - newQInternal[i] = _uintToInternalFloor(bal, bases[i]); - - // For initial deposit, record the actual deposited amounts - if (isInitialDeposit) { - depositAmounts[i] = bal; - } - - unchecked { i++; } - } - - // If first time, call init, otherwise update proportional change. - if (isInitialDeposit) { - // Initialize the stabilized LMSR state - lmsr.init(newQInternal, tradeFrac, targetSlippage); - } else { - // Update for proportional change - lmsr.updateForProportionalChange(newQInternal); - } - - // Compute actual LP tokens to mint based on change in size metric (scaled) - // floor truncation rounds in favor of the pool - int128 newTotal = _computeSizeMetric(newQInternal); - uint256 newScaled = ABDKMath64x64.mulu(newTotal, LP_SCALE); - uint256 actualLpToMint; - - if (isInitialDeposit) { - // Initial provisioning: mint newScaled (as LP units) - actualLpToMint = newScaled; - } else { - require(oldScaled > 0, "mint: oldScaled zero"); - uint256 delta = (newScaled > oldScaled) ? (newScaled - oldScaled) : 0; - // Proportional issuance: totalSupply * delta / oldScaled - if (delta > 0) { - // floor truncation rounds in favor of the pool - actualLpToMint = (totalSupply() * delta) / oldScaled; - } else { - actualLpToMint = 0; - } - } - - // For subsequent mints, ensure the calculated LP amount is not too different from requested - if (!isInitialDeposit) { - // Allow for some rounding error but ensure we're not far off from requested amount - require(actualLpToMint > 0, "mint: zero LP minted"); - - // Allow actual amount to be at most 0.00001% less than requested - // This accounts for rounding in deposit calculations - uint256 minAcceptable = lpTokenAmount * 99_999 / 100_000; - require(actualLpToMint >= minAcceptable, "mint: insufficient LP minted"); - } - - require( actualLpToMint > 0, "mint: zero LP amount"); - _mint(receiver, actualLpToMint); - emit Mint(payer, receiver, depositAmounts, actualLpToMint); - } - /// @notice Burn LP tokens and withdraw the proportional basket to receiver. /// @dev Payer must own or approve the LP tokens being burned. The function updates LMSR state /// proportionally to reflect the reduced pool size after the withdrawal. @@ -324,7 +338,7 @@ contract PartyPool is IPartyPool, ERC20, ReentrancyGuard { // Transfer underlying tokens out to receiver according to computed proportions for (uint i = 0; i < n; ) { if (withdrawAmounts[i] > 0) { - _safeTransfer(tokens[i], receiver, withdrawAmounts[i]); + tokens[i].safeTransfer(receiver, withdrawAmounts[i]); } unchecked { i++; } } @@ -369,6 +383,139 @@ contract PartyPool is IPartyPool, ERC20, ReentrancyGuard { Swaps ---------------------- */ + /// @inheritdoc IPartyPool + function swapAmounts( + uint256 inputTokenIndex, + uint256 outputTokenIndex, + uint256 maxAmountIn, + int128 limitPrice + ) external view returns (uint256 amountIn, uint256 amountOut, uint256 fee) { + (uint256 grossIn, uint256 outUint,,,, uint256 feeUint) = _quoteSwapExactIn(inputTokenIndex, outputTokenIndex, maxAmountIn, limitPrice); + return (grossIn, outUint, feeUint); + } + + /// @inheritdoc IPartyPool + function swapToLimitAmounts( + uint256 inputTokenIndex, + uint256 outputTokenIndex, + int128 limitPrice + ) external view returns (uint256 amountIn, uint256 amountOut, uint256 fee) { + (uint256 grossIn, uint256 outUint,,,, uint256 feeUint) = _quoteSwapToLimit(inputTokenIndex, outputTokenIndex, limitPrice); + return (grossIn, outUint, feeUint); + } + + + /// @notice Swap input token i -> token j. Payer must approve token i. + /// @dev This function transfers the exact gross input (including fee) from payer and sends the computed output to receiver. + /// Non-standard tokens (fee-on-transfer, rebasers) are rejected via balance checks. + /// @param payer address of the account that pays for the swap + /// @param receiver address that will receive the output tokens + /// @param inputTokenIndex index of input asset + /// @param outputTokenIndex index of output asset + /// @param maxAmountIn maximum amount of token i (uint256) to transfer in (inclusive of fees) + /// @param limitPrice maximum acceptable marginal price (64.64 fixed point). Pass 0 to ignore. + /// @param deadline timestamp after which the transaction will revert. Pass 0 to ignore. + /// @return amountIn actual input used (uint256), amountOut actual output sent (uint256), fee fee taken from the input (uint256) + function swap( + address payer, + address receiver, + uint256 inputTokenIndex, + uint256 outputTokenIndex, + uint256 maxAmountIn, + int128 limitPrice, + uint256 deadline + ) external nonReentrant returns (uint256 amountIn, uint256 amountOut, uint256 fee) { + uint256 n = tokens.length; + require(inputTokenIndex < n && outputTokenIndex < n, "swap: idx"); + require(maxAmountIn > 0, "swap: input zero"); + require(deadline == 0 || block.timestamp <= deadline, "swap: deadline exceeded"); + + // Read previous balances for affected assets + uint256 prevBalI = IERC20(tokens[inputTokenIndex]).balanceOf(address(this)); + uint256 prevBalJ = IERC20(tokens[outputTokenIndex]).balanceOf(address(this)); + + // Compute amounts using the same path as views + (uint256 totalTransferAmount, uint256 amountOutUint, int128 amountInInternalUsed, int128 amountOutInternal, , uint256 feeUint) = + _quoteSwapExactIn(inputTokenIndex, outputTokenIndex, maxAmountIn, limitPrice); + + // Transfer the exact amount from payer and require exact receipt (revert on fee-on-transfer) + tokens[inputTokenIndex].safeTransferFrom(payer, address(this), totalTransferAmount); + uint256 balIAfter = IERC20(tokens[inputTokenIndex]).balanceOf(address(this)); + require(balIAfter == prevBalI + totalTransferAmount, "swap: non-standard tokenIn"); + + // Transfer output to receiver and verify exact decrease + tokens[outputTokenIndex].safeTransfer(receiver, amountOutUint); + uint256 balJAfter = IERC20(tokens[outputTokenIndex]).balanceOf(address(this)); + require(balJAfter == prevBalJ - amountOutUint, "swap: non-standard tokenOut"); + + // Update cached uint balances for i and j using actual balances + cachedUintBalances[inputTokenIndex] = balIAfter; + cachedUintBalances[outputTokenIndex] = balJAfter; + + // Apply swap to LMSR state with the internal amounts actually used + lmsr.applySwap(inputTokenIndex, outputTokenIndex, amountInInternalUsed, amountOutInternal); + + emit Swap(payer, receiver, tokens[inputTokenIndex], tokens[outputTokenIndex], totalTransferAmount, amountOutUint); + + return (totalTransferAmount, amountOutUint, feeUint); + } + + /// @notice Swap up to the price limit; computes max input to reach limit then performs swap. + /// @dev If balances prevent fully reaching the limit, the function caps and returns actuals. + /// The payer must transfer the exact gross input computed by the view. + /// @param deadline timestamp after which the transaction will revert. Pass 0 to ignore. + function swapToLimit( + address payer, + address receiver, + uint256 inputTokenIndex, + uint256 outputTokenIndex, + int128 limitPrice, + uint256 deadline + ) external returns (uint256 amountInUsed, uint256 amountOut, uint256 fee) { + uint256 n = tokens.length; + require(inputTokenIndex < n && outputTokenIndex < n, "swapToLimit: idx"); + require(limitPrice > int128(0), "swapToLimit: limit <= 0"); + require(deadline == 0 || block.timestamp <= deadline, "swapToLimit: deadline exceeded"); + + // Read previous balances for affected assets + uint256 prevBalI = IERC20(tokens[inputTokenIndex]).balanceOf(address(this)); + uint256 prevBalJ = IERC20(tokens[outputTokenIndex]).balanceOf(address(this)); + + // Compute amounts using the same path as views + (uint256 totalTransferAmount, uint256 amountOutUint, int128 amountInInternalMax, int128 amountOutInternal, uint256 amountInUsedUint, uint256 feeUint) = + _quoteSwapToLimit(inputTokenIndex, outputTokenIndex, limitPrice); + + // Transfer the exact amount needed from payer and require exact receipt (revert on fee-on-transfer) + tokens[inputTokenIndex].safeTransferFrom(payer, address(this), totalTransferAmount); + uint256 balIAfter = IERC20(tokens[inputTokenIndex]).balanceOf(address(this)); + require(balIAfter == prevBalI + totalTransferAmount, "swapToLimit: non-standard tokenIn"); + + // Transfer output to receiver and verify exact decrease + tokens[outputTokenIndex].safeTransfer(receiver, amountOutUint); + uint256 balJAfter = IERC20(tokens[outputTokenIndex]).balanceOf(address(this)); + require(balJAfter == prevBalJ - amountOutUint, "swapToLimit: non-standard tokenOut"); + + // Update caches to actual balances + cachedUintBalances[inputTokenIndex] = balIAfter; + cachedUintBalances[outputTokenIndex] = balJAfter; + + // Apply swap to LMSR state with the internal amounts + lmsr.applySwap(inputTokenIndex, outputTokenIndex, amountInInternalMax, amountOutInternal); + + // Maintain original event semantics (logs input without fee) + emit Swap(payer, receiver, tokens[inputTokenIndex], tokens[outputTokenIndex], amountInUsedUint, amountOutUint); + + return (amountInUsedUint, amountOutUint, feeUint); + } + + /// @notice Ceiling fee helper: computes ceil(x * feePpm / 1_000_000) + /// @dev Internal helper; public-facing functions use this to ensure fees round up in favor of pool. + function _ceilFee(uint256 x, uint256 feePpm) internal pure returns (uint256) { + if (feePpm == 0) return 0; + // ceil division: (num + denom - 1) / denom + return (x * feePpm + 1_000_000 - 1) / 1_000_000; + } + /// @notice Internal quote for exact-input swap that mirrors swap() rounding and fee application /// @dev Returns amounts consistent with swap() semantics: grossIn includes fees (ceil), amountOut is floored. /// @return grossIn amount to transfer in (inclusive of fee), amountOutUint output amount (uint), @@ -474,139 +621,6 @@ contract PartyPool is IPartyPool, ERC20, ReentrancyGuard { require(amountOutUint > 0, "swapToLimit: output zero"); } - /// @inheritdoc IPartyPool - function swapAmounts( - uint256 inputTokenIndex, - uint256 outputTokenIndex, - uint256 maxAmountIn, - int128 limitPrice - ) external view returns (uint256 amountIn, uint256 amountOut, uint256 fee) { - (uint256 grossIn, uint256 outUint,,,, uint256 feeUint) = _quoteSwapExactIn(inputTokenIndex, outputTokenIndex, maxAmountIn, limitPrice); - return (grossIn, outUint, feeUint); - } - - /// @inheritdoc IPartyPool - function swapToLimitAmounts( - uint256 inputTokenIndex, - uint256 outputTokenIndex, - int128 limitPrice - ) external view returns (uint256 amountIn, uint256 amountOut, uint256 fee) { - (uint256 grossIn, uint256 outUint,,,, uint256 feeUint) = _quoteSwapToLimit(inputTokenIndex, outputTokenIndex, limitPrice); - return (grossIn, outUint, feeUint); - } - - - /// @notice Swap input token i -> token j. Payer must approve token i. - /// @dev This function transfers the exact gross input (including fee) from payer and sends the computed output to receiver. - /// Non-standard tokens (fee-on-transfer, rebasers) are rejected via balance checks. - /// @param payer address of the account that pays for the swap - /// @param receiver address that will receive the output tokens - /// @param inputTokenIndex index of input asset - /// @param outputTokenIndex index of output asset - /// @param maxAmountIn maximum amount of token i (uint256) to transfer in (inclusive of fees) - /// @param limitPrice maximum acceptable marginal price (64.64 fixed point). Pass 0 to ignore. - /// @param deadline timestamp after which the transaction will revert. Pass 0 to ignore. - /// @return amountIn actual input used (uint256), amountOut actual output sent (uint256), fee fee taken from the input (uint256) - function swap( - address payer, - address receiver, - uint256 inputTokenIndex, - uint256 outputTokenIndex, - uint256 maxAmountIn, - int128 limitPrice, - uint256 deadline - ) external nonReentrant returns (uint256 amountIn, uint256 amountOut, uint256 fee) { - uint256 n = tokens.length; - require(inputTokenIndex < n && outputTokenIndex < n, "swap: idx"); - require(maxAmountIn > 0, "swap: input zero"); - require(deadline == 0 || block.timestamp <= deadline, "swap: deadline exceeded"); - - // Read previous balances for affected assets - uint256 prevBalI = IERC20(tokens[inputTokenIndex]).balanceOf(address(this)); - uint256 prevBalJ = IERC20(tokens[outputTokenIndex]).balanceOf(address(this)); - - // Compute amounts using the same path as views - (uint256 totalTransferAmount, uint256 amountOutUint, int128 amountInInternalUsed, int128 amountOutInternal, , uint256 feeUint) = - _quoteSwapExactIn(inputTokenIndex, outputTokenIndex, maxAmountIn, limitPrice); - - // Transfer the exact amount from payer and require exact receipt (revert on fee-on-transfer) - _safeTransferFrom(tokens[inputTokenIndex], payer, address(this), totalTransferAmount); - uint256 balIAfter = IERC20(tokens[inputTokenIndex]).balanceOf(address(this)); - require(balIAfter == prevBalI + totalTransferAmount, "swap: non-standard tokenIn"); - - // Transfer output to receiver and verify exact decrease - _safeTransfer(tokens[outputTokenIndex], receiver, amountOutUint); - uint256 balJAfter = IERC20(tokens[outputTokenIndex]).balanceOf(address(this)); - require(balJAfter == prevBalJ - amountOutUint, "swap: non-standard tokenOut"); - - // Update cached uint balances for i and j using actual balances - cachedUintBalances[inputTokenIndex] = balIAfter; - cachedUintBalances[outputTokenIndex] = balJAfter; - - // Apply swap to LMSR state with the internal amounts actually used - lmsr.applySwap(inputTokenIndex, outputTokenIndex, amountInInternalUsed, amountOutInternal); - - emit Swap(payer, receiver, tokens[inputTokenIndex], tokens[outputTokenIndex], totalTransferAmount, amountOutUint); - - return (totalTransferAmount, amountOutUint, feeUint); - } - - /// @notice Swap up to the price limit; computes max input to reach limit then performs swap. - /// @dev If balances prevent fully reaching the limit, the function caps and returns actuals. - /// The payer must transfer the exact gross input computed by the view. - /// @param deadline timestamp after which the transaction will revert. Pass 0 to ignore. - function swapToLimit( - address payer, - address receiver, - uint256 inputTokenIndex, - uint256 outputTokenIndex, - int128 limitPrice, - uint256 deadline - ) external returns (uint256 amountInUsed, uint256 amountOut, uint256 fee) { - uint256 n = tokens.length; - require(inputTokenIndex < n && outputTokenIndex < n, "swapToLimit: idx"); - require(limitPrice > int128(0), "swapToLimit: limit <= 0"); - require(deadline == 0 || block.timestamp <= deadline, "swapToLimit: deadline exceeded"); - - // Read previous balances for affected assets - uint256 prevBalI = IERC20(tokens[inputTokenIndex]).balanceOf(address(this)); - uint256 prevBalJ = IERC20(tokens[outputTokenIndex]).balanceOf(address(this)); - - // Compute amounts using the same path as views - (uint256 totalTransferAmount, uint256 amountOutUint, int128 amountInInternalMax, int128 amountOutInternal, uint256 amountInUsedUint, uint256 feeUint) = - _quoteSwapToLimit(inputTokenIndex, outputTokenIndex, limitPrice); - - // Transfer the exact amount needed from payer and require exact receipt (revert on fee-on-transfer) - _safeTransferFrom(tokens[inputTokenIndex], payer, address(this), totalTransferAmount); - uint256 balIAfter = IERC20(tokens[inputTokenIndex]).balanceOf(address(this)); - require(balIAfter == prevBalI + totalTransferAmount, "swapToLimit: non-standard tokenIn"); - - // Transfer output to receiver and verify exact decrease - _safeTransfer(tokens[outputTokenIndex], receiver, amountOutUint); - uint256 balJAfter = IERC20(tokens[outputTokenIndex]).balanceOf(address(this)); - require(balJAfter == prevBalJ - amountOutUint, "swapToLimit: non-standard tokenOut"); - - // Update caches to actual balances - cachedUintBalances[inputTokenIndex] = balIAfter; - cachedUintBalances[outputTokenIndex] = balJAfter; - - // Apply swap to LMSR state with the internal amounts - lmsr.applySwap(inputTokenIndex, outputTokenIndex, amountInInternalMax, amountOutInternal); - - // Maintain original event semantics (logs input without fee) - emit Swap(payer, receiver, tokens[inputTokenIndex], tokens[outputTokenIndex], amountInUsedUint, amountOutUint); - - return (amountInUsedUint, amountOutUint, feeUint); - } - - /// @notice Ceiling fee helper: computes ceil(x * feePpm / 1_000_000) - /// @dev Internal helper; public-facing functions use this to ensure fees round up in favor of pool. - function _ceilFee(uint256 x, uint256 feePpm) internal pure returns (uint256) { - if (feePpm == 0) return 0; - // ceil division: (num + denom - 1) / denom - return (x * feePpm + 1_000_000 - 1) / 1_000_000; - } - /// @notice Compute fee and net amounts for a gross input (fee rounded up to favor the pool). /// @return feeUint fee taken (uint) and netUint remaining for protocol use (uint) function _computeFee(uint256 gross) internal view returns (uint256 feeUint, uint256 netUint) { @@ -673,7 +687,7 @@ contract PartyPool is IPartyPool, ERC20, ReentrancyGuard { // Record pre-balance and transfer tokens from payer, require exact receipt (revert on fee-on-transfer) uint256 prevBalI = IERC20(tokens[inputTokenIndex]).balanceOf(address(this)); - _safeTransferFrom(tokens[inputTokenIndex], payer, address(this), totalTransfer); + tokens[inputTokenIndex].safeTransferFrom(payer, address(this), totalTransfer); uint256 balIAfter = IERC20(tokens[inputTokenIndex]).balanceOf(address(this)); require(balIAfter == prevBalI + totalTransfer, "swapMint: non-standard tokenIn"); @@ -765,7 +779,7 @@ contract PartyPool is IPartyPool, ERC20, ReentrancyGuard { require(amountOutUint > 0, "burnSwap: output zero"); // Transfer the payout to receiver - _safeTransfer(tokens[inputTokenIndex], receiver, amountOutUint); + tokens[inputTokenIndex].safeTransfer(receiver, amountOutUint); // Burn LP tokens from payer (authorization via allowance) if (msg.sender != payer) { @@ -853,7 +867,7 @@ contract PartyPool is IPartyPool, ERC20, ReentrancyGuard { initialBalances[i] = IERC20(tokens[i]).balanceOf(address(this)); // Transfer token to recipient - _safeTransfer(tokens[i], recipient, amount); + tokens[i].safeTransfer(recipient, amount); } } @@ -913,18 +927,6 @@ contract PartyPool is IPartyPool, ERC20, ReentrancyGuard { return floorValue; } - /* ---------------------- - ERC20 helpers (minimal) - ---------------------- */ - - function _safeTransferFrom(address token, address from, address to, uint256 amt) internal { - IERC20(token).safeTransferFrom(from, to, amt); - } - - function _safeTransfer(address token, address to, uint256 amt) internal { - IERC20(token).safeTransfer(to, amt); - } - /// @notice Helper to compute size metric (sum of all asset quantities) from internal balances /// @dev Returns the sum of all provided qInternal_ entries as a Q64.64 value. function _computeSizeMetric(int128[] memory qInternal_) private pure returns (int128) { diff --git a/test/GasTest.sol b/test/GasTest.sol index 9e79cbc..a6acee7 100644 --- a/test/GasTest.sol +++ b/test/GasTest.sol @@ -25,9 +25,12 @@ contract FlashBorrower is IPartyFlashCallback { address public recipient; address[] public tokens; - constructor(address _pool, address[] memory _tokens) { + constructor(address _pool, IERC20[] memory _tokens) { pool = _pool; - tokens = _tokens; + tokens = new address[](_tokens.length); + for (uint i = 0; i < _tokens.length; i++) { + tokens[i] = address(_tokens[i]); + } } function setAction(Action _action, address _recipient) external { @@ -162,7 +165,11 @@ contract GasTest is Test { // Deploy pool with a small fee to test fee-handling paths (use 1000 ppm = 0.1%) uint256 feePpm = 1000; string memory poolName = string(abi.encodePacked("LP", vm.toString(numTokens))); - PartyPool newPool = new PartyPool(poolName, poolName, tokens, bases, tradeFrac, targetSlippage, feePpm, feePpm, false); + IERC20[] memory ierc20Tokens = new IERC20[](tokens.length); + for (uint i = 0; i < tokens.length; i++) { + ierc20Tokens[i] = IERC20(tokens[i]); + } + PartyPool newPool = new PartyPool(poolName, poolName, ierc20Tokens, bases, tradeFrac, targetSlippage, feePpm, feePpm, false); // Transfer initial deposit amounts into pool before initial mint for (uint256 i = 0; i < numTokens; i++) { @@ -170,7 +177,7 @@ contract GasTest is Test { } // Perform initial mint (initial deposit); receiver is this contract - newPool.mint(address(0), address(this), 0, 0); + newPool.initialMint(address(this), 0); return newPool; } @@ -197,7 +204,11 @@ contract GasTest is Test { uint256 feePpm = 1000; string memory poolName = string(abi.encodePacked("LPs", vm.toString(numTokens))); // Note the final 'true' arg to activate stable-pair optimization path - PartyPool newPool = new PartyPool(poolName, poolName, tokens, bases, tradeFrac, targetSlippage, feePpm, feePpm, true); + IERC20[] memory ierc20Tokens = new IERC20[](tokens.length); + for (uint i = 0; i < tokens.length; i++) { + ierc20Tokens[i] = IERC20(tokens[i]); + } + PartyPool newPool = new PartyPool(poolName, poolName, ierc20Tokens, bases, tradeFrac, targetSlippage, feePpm, feePpm, true); // Transfer initial deposit amounts into pool before initial mint for (uint256 i = 0; i < numTokens; i++) { @@ -205,7 +216,7 @@ contract GasTest is Test { } // Perform initial mint (initial deposit); receiver is this contract - newPool.mint(address(0), address(this), 0, 0); + newPool.initialMint(address(this), 0); return newPool; } @@ -228,7 +239,7 @@ contract GasTest is Test { /// @notice Setup a flash borrower for testing function setupFlashBorrower() internal returns (FlashBorrower borrower) { // Get token addresses from the 2-token pool - address[] memory tokenAddresses = pool2.allTokens(); + IERC20[] memory tokenAddresses = pool2.allTokens(); // Deploy the borrower contract borrower = new FlashBorrower(address(pool2), tokenAddresses); @@ -236,22 +247,22 @@ contract GasTest is Test { // Mint tokens to alice to be used for repayments and approve borrower vm.startPrank(alice); for (uint256 i = 0; i < tokenAddresses.length; i++) { - TestERC20(tokenAddresses[i]).mint(alice, INIT_BAL * 2); - TestERC20(tokenAddresses[i]).approve(address(borrower), type(uint256).max); + TestERC20(address(tokenAddresses[i])).mint(alice, INIT_BAL * 2); + TestERC20(address(tokenAddresses[i])).approve(address(borrower), type(uint256).max); } vm.stopPrank(); } /// @notice Helper function: perform 10 swaps back-and-forth between the first two tokens. function _performSwapGasTest(PartyPool testPool) internal { - address[] memory tokens = testPool.allTokens(); + IERC20[] memory tokens = testPool.allTokens(); require(tokens.length >= 2, "Pool must have at least 2 tokens"); // Ensure alice approves pool for both tokens vm.prank(alice); - TestERC20(tokens[0]).approve(address(testPool), type(uint256).max); + TestERC20(address(tokens[0])).approve(address(testPool), type(uint256).max); vm.prank(alice); - TestERC20(tokens[1]).approve(address(testPool), type(uint256).max); + TestERC20(address(tokens[1])).approve(address(testPool), type(uint256).max); uint256 maxIn = 1_000; @@ -310,13 +321,13 @@ contract GasTest is Test { function _performSwapMintBurnSwapGasTest(PartyPool testPool) internal { uint256 iterations = 10; uint256 input = 1_000; - address[] memory tokens = testPool.allTokens(); + IERC20[] memory tokens = testPool.allTokens(); // Top up alice so repeated operations won't fail - TestERC20(tokens[0]).mint(alice, iterations * input * 2); + TestERC20(address(tokens[0])).mint(alice, iterations * input * 2); vm.startPrank(alice); - TestERC20(tokens[0]).approve(address(testPool), type(uint256).max); + TestERC20(address(tokens[0])).approve(address(testPool), type(uint256).max); for (uint256 k = 0; k < iterations; k++) { // Mint LP by providing single-token input; receive LP minted @@ -355,14 +366,14 @@ contract GasTest is Test { function _performMintBurnGasTest(PartyPool testPool) internal { uint256 iterations = 50; uint256 input = 1_000; - address[] memory poolTokens = testPool.allTokens(); + IERC20[] memory poolTokens = testPool.allTokens(); vm.startPrank(alice); // Mint additional tokens to alice and approve pool to transfer tokens for proportional mint for (uint256 i = 0; i < poolTokens.length; i++) { - TestERC20(poolTokens[i]).mint(alice, iterations * input * 2); - TestERC20(poolTokens[i]).approve(address(testPool), type(uint256).max); + TestERC20(address(poolTokens[i])).mint(alice, iterations * input * 2); + TestERC20(address(poolTokens[i])).approve(address(testPool), type(uint256).max); } for (uint256 k = 0; k < iterations; k++) { @@ -422,7 +433,7 @@ contract GasTest is Test { borrower.setAction(FlashBorrower.Action.NORMAL, alice); // Create loan request for single token (get array size from pool) - address[] memory poolTokens = pool2.allTokens(); + IERC20[] memory poolTokens = pool2.allTokens(); uint256[] memory amounts = new uint256[](poolTokens.length); amounts[0] = 1000; @@ -440,7 +451,7 @@ contract GasTest is Test { borrower.setAction(FlashBorrower.Action.NORMAL, alice); // Create loan request for multiple tokens (get array size from pool) - address[] memory poolTokens = pool2.allTokens(); + IERC20[] memory poolTokens = pool2.allTokens(); uint256[] memory amounts = new uint256[](poolTokens.length); amounts[0] = 1000; amounts[1] = 2000; diff --git a/test/PartyPlanner.t.sol b/test/PartyPlanner.t.sol new file mode 100644 index 0000000..f647cb2 --- /dev/null +++ b/test/PartyPlanner.t.sol @@ -0,0 +1,316 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.30; + +import "forge-std/Test.sol"; +import "../src/PartyPlanner.sol"; +import "../src/PartyPool.sol"; +import "@openzeppelin/contracts/token/ERC20/ERC20.sol"; + +// Mock ERC20 token for testing +contract MockERC20 is ERC20 { + uint8 private _decimals; + + constructor(string memory name, string memory symbol, uint8 decimals_) ERC20(name, symbol) { + _decimals = decimals_; + } + + function mint(address to, uint256 amount) external { + _mint(to, amount); + } + + function decimals() public view override returns (uint8) { + return _decimals; + } +} + +contract PartyPlannerTest is Test { + PartyPlanner public planner; + MockERC20 public tokenA; + MockERC20 public tokenB; + MockERC20 public tokenC; + + address public payer = makeAddr("payer"); + address public receiver = makeAddr("receiver"); + + uint256 constant INITIAL_MINT_AMOUNT = 1000000e18; + uint256 constant INITIAL_DEPOSIT_AMOUNT = 1000e18; + + function setUp() public { + // Deploy PartyPlanner + planner = new PartyPlanner(); + + // Deploy mock tokens + tokenA = new MockERC20("Token A", "TKNA", 18); + tokenB = new MockERC20("Token B", "TKNB", 18); + tokenC = new MockERC20("Token C", "TKNC", 6); + + // Mint tokens to payer + tokenA.mint(payer, INITIAL_MINT_AMOUNT); + tokenB.mint(payer, INITIAL_MINT_AMOUNT); + tokenC.mint(payer, INITIAL_MINT_AMOUNT); + + // Approve tokens for PartyPlanner + vm.startPrank(payer); + tokenA.approve(address(planner), type(uint256).max); + tokenB.approve(address(planner), type(uint256).max); + tokenC.approve(address(planner), type(uint256).max); + vm.stopPrank(); + } + + function test_createPool_Success() public { + // Prepare pool parameters + string memory name = "Test Pool"; + string memory symbol = "TESTLP"; + IERC20[] memory tokens = new IERC20[](2); + tokens[0] = IERC20(address(tokenA)); + tokens[1] = IERC20(address(tokenB)); + + uint256[] memory bases = new uint256[](2); + bases[0] = 1e18; // 18 decimals + bases[1] = 1e18; // 18 decimals + + uint256[] memory initialDeposits = new uint256[](2); + initialDeposits[0] = INITIAL_DEPOSIT_AMOUNT; + initialDeposits[1] = INITIAL_DEPOSIT_AMOUNT; + + // Fixed point parameters (using simple values for testing) + int128 tradeFrac = int128((1 << 64) - 1); // slightly less than 1.0 in 64.64 fixed point + int128 targetSlippage = int128(1 << 62); // 0.25 in 64.64 fixed point + uint256 swapFeePpm = 3000; // 0.3% + uint256 flashFeePpm = 5000; // 0.5% + + uint256 initialPoolCount = planner.poolCount(); + uint256 initialTokenACount = planner.poolsByTokenCount(IERC20(address(tokenA))); + uint256 initialTokenBCount = planner.poolsByTokenCount(IERC20(address(tokenB))); + + // Create pool + (PartyPool pool, uint256 lpAmount) = planner.createPool( + name, + symbol, + tokens, + bases, + tradeFrac, + targetSlippage, + swapFeePpm, + flashFeePpm, + false, // not stable + payer, + receiver, + initialDeposits, + 1000e18, // initial LP amount + 0 // no deadline + ); + + // Verify pool was created + assertNotEq(address(pool), address(0), "Pool should be created"); + assertGt(lpAmount, 0, "LP tokens should be minted"); + + // Verify pool is indexed correctly + assertEq(planner.poolCount(), initialPoolCount + 1, "Pool count should increase by 1"); + assertTrue(planner.getPoolSupported(address(pool)), "Pool should be marked as supported"); + + // Verify token indexing + assertEq(planner.poolsByTokenCount(IERC20(address(tokenA))), initialTokenACount + 1, "TokenA pool count should increase"); + assertEq(planner.poolsByTokenCount(IERC20(address(tokenB))), initialTokenBCount + 1, "TokenB pool count should increase"); + + // Verify pools can be retrieved + PartyPool[] memory allPools = planner.getAllPools(0, 10); + bool poolFound = false; + for (uint256 i = 0; i < allPools.length; i++) { + if (allPools[i] == pool) { + poolFound = true; + break; + } + } + assertTrue(poolFound, "Created pool should be in getAllPools result"); + + // Verify pool appears in token-specific queries + PartyPool[] memory tokenAPools = planner.getPoolsByToken(IERC20(address(tokenA)), 0, 10); + bool poolInTokenA = false; + for (uint256 i = 0; i < tokenAPools.length; i++) { + if (tokenAPools[i] == pool) { + poolInTokenA = true; + break; + } + } + assertTrue(poolInTokenA, "Pool should be indexed under tokenA"); + + PartyPool[] memory tokenBPools = planner.getPoolsByToken(IERC20(address(tokenB)), 0, 10); + bool poolInTokenB = false; + for (uint256 i = 0; i < tokenBPools.length; i++) { + if (tokenBPools[i] == pool) { + poolInTokenB = true; + break; + } + } + assertTrue(poolInTokenB, "Pool should be indexed under tokenB"); + + // Verify LP tokens were minted to receiver + assertEq(pool.balanceOf(receiver), lpAmount, "Receiver should have LP tokens"); + } + + function test_createPool_MultiplePoolsIndexing() public { + // Create first pool with tokenA and tokenB + IERC20[] memory tokens1 = new IERC20[](2); + tokens1[0] = IERC20(address(tokenA)); + tokens1[1] = IERC20(address(tokenB)); + + uint256[] memory bases1 = new uint256[](2); + bases1[0] = 1e18; + bases1[1] = 1e18; + + uint256[] memory deposits1 = new uint256[](2); + deposits1[0] = INITIAL_DEPOSIT_AMOUNT; + deposits1[1] = INITIAL_DEPOSIT_AMOUNT; + + (PartyPool pool1,) = planner.createPool( + "Pool 1", "LP1", tokens1, bases1, + int128((1 << 64) - 1), int128(1 << 62), 3000, 5000, false, + payer, receiver, deposits1, 1000e18, 0 + ); + + // Create second pool with tokenB and tokenC + IERC20[] memory tokens2 = new IERC20[](2); + tokens2[0] = IERC20(address(tokenB)); + tokens2[1] = IERC20(address(tokenC)); + + uint256[] memory bases2 = new uint256[](2); + bases2[0] = 1e18; + bases2[1] = 1e6; // tokenC has 6 decimals + + uint256[] memory deposits2 = new uint256[](2); + deposits2[0] = INITIAL_DEPOSIT_AMOUNT; + deposits2[1] = INITIAL_DEPOSIT_AMOUNT / 1e12; // Adjust for 6 decimals + + (PartyPool pool2,) = planner.createPool( + "Pool 2", "LP2", tokens2, bases2, + int128((1 << 64) - 1), int128(1 << 62), 3000, 5000, false, + payer, receiver, deposits2, 1000e18, 0 + ); + + // Verify indexing + assertEq(planner.poolCount(), 2, "Should have 2 pools"); + assertEq(planner.tokenCount(), 3, "Should have 3 unique tokens"); + + // Verify token-pool relationships + assertEq(planner.poolsByTokenCount(IERC20(address(tokenA))), 1, "TokenA should be in 1 pool"); + assertEq(planner.poolsByTokenCount(IERC20(address(tokenB))), 2, "TokenB should be in 2 pools"); + assertEq(planner.poolsByTokenCount(IERC20(address(tokenC))), 1, "TokenC should be in 1 pool"); + + // Verify tokenB appears in both pools + PartyPool[] memory tokenBPools = planner.getPoolsByToken(IERC20(address(tokenB)), 0, 10); + assertEq(tokenBPools.length, 2, "TokenB should have 2 pools"); + + bool pool1Found = false; + bool pool2Found = false; + for (uint256 i = 0; i < tokenBPools.length; i++) { + if (tokenBPools[i] == pool1) pool1Found = true; + if (tokenBPools[i] == pool2) pool2Found = true; + } + assertTrue(pool1Found, "Pool1 should be in tokenB pools"); + assertTrue(pool2Found, "Pool2 should be in tokenB pools"); + } + + function test_createPool_InvalidInputs() public { + IERC20[] memory tokens = new IERC20[](2); + tokens[0] = IERC20(address(tokenA)); + tokens[1] = IERC20(address(tokenB)); + + uint256[] memory bases = new uint256[](2); + bases[0] = 1e18; + bases[1] = 1e18; + + uint256[] memory deposits = new uint256[](1); // Mismatched length + deposits[0] = INITIAL_DEPOSIT_AMOUNT; + + // Test token/deposit length mismatch + vm.expectRevert("Planner: tokens and deposits length mismatch"); + planner.createPool( + "Test Pool", "TESTLP", tokens, bases, + int128((1 << 64) - 1), int128(1 << 62), 3000, 5000, false, + payer, receiver, deposits, 1000e18, 0 + ); + + // Test zero payer address + uint256[] memory validDeposits = new uint256[](2); + validDeposits[0] = INITIAL_DEPOSIT_AMOUNT; + validDeposits[1] = INITIAL_DEPOSIT_AMOUNT; + + vm.expectRevert("Planner: payer cannot be zero address"); + planner.createPool( + "Test Pool", "TESTLP", tokens, bases, + int128((1 << 64) - 1), int128(1 << 62), 3000, 5000, false, + address(0), receiver, validDeposits, 1000e18, 0 + ); + + // Test zero receiver address + vm.expectRevert("Planner: receiver cannot be zero address"); + planner.createPool( + "Test Pool", "TESTLP", tokens, bases, + int128((1 << 64) - 1), int128(1 << 62), 3000, 5000, false, + payer, address(0), validDeposits, 1000e18, 0 + ); + + // Test deadline exceeded + // The default timestamp is 1 and 1-0 is 0 which means "ignore deadline," so we need to set a proper timestamp. + vm.warp(1000); + vm.expectRevert("Planner: deadline exceeded"); + planner.createPool( + "Test Pool", "TESTLP", tokens, bases, + int128((1 << 64) - 1), int128(1 << 62), 3000, 5000, false, + payer, receiver, validDeposits, 1000e18, block.timestamp - 1 + ); + } + + function test_poolIndexing_Pagination() public { + // Create multiple pools for pagination testing + uint256 numPools = 5; + PartyPool[] memory createdPools = new PartyPool[](numPools); + + for (uint256 i = 0; i < numPools; i++) { + IERC20[] memory tokens = new IERC20[](2); + tokens[0] = IERC20(address(tokenA)); + tokens[1] = IERC20(address(tokenB)); + + uint256[] memory bases = new uint256[](2); + bases[0] = 1e18; + bases[1] = 1e18; + + uint256[] memory deposits = new uint256[](2); + deposits[0] = INITIAL_DEPOSIT_AMOUNT; + deposits[1] = INITIAL_DEPOSIT_AMOUNT; + + (PartyPool pool,) = planner.createPool( + string(abi.encodePacked("Pool ", vm.toString(i))), + string(abi.encodePacked("LP", vm.toString(i))), + tokens, bases, + int128((1 << 64) - 1), int128(1 << 62), 3000, 5000, false, + payer, receiver, deposits, 1000e18, 0 + ); + + createdPools[i] = pool; + } + + assertEq(planner.poolCount(), numPools, "Should have created all pools"); + + // Test pagination - get first 3 pools + PartyPool[] memory page1 = planner.getAllPools(0, 3); + assertEq(page1.length, 3, "First page should have 3 pools"); + + // Test pagination - get next 2 pools + PartyPool[] memory page2 = planner.getAllPools(3, 3); + assertEq(page2.length, 2, "Second page should have 2 pools"); + + // Test pagination - offset beyond bounds + PartyPool[] memory emptyPage = planner.getAllPools(10, 3); + assertEq(emptyPage.length, 0, "Should return empty array for out of bounds offset"); + + // Verify all pools are accessible through pagination + PartyPool[] memory allPools = planner.getAllPools(0, 10); + assertEq(allPools.length, numPools, "Should return all pools"); + + for (uint256 i = 0; i < numPools; i++) { + assertEq(address(allPools[i]), address(createdPools[i]), "Pool order should be preserved"); + } + } +} diff --git a/test/PartyPool.t.sol b/test/PartyPool.t.sol index f3617a8..f4c6d53 100644 --- a/test/PartyPool.t.sol +++ b/test/PartyPool.t.sol @@ -182,10 +182,10 @@ contract PartyPoolTest is Test { targetSlippage = ABDKMath64x64.divu(10, 10_000); // 0.001 // Build arrays for pool constructor - address[] memory tokens = new address[](3); - tokens[0] = address(token0); - tokens[1] = address(token1); - tokens[2] = address(token2); + IERC20[] memory tokens = new IERC20[](3); + tokens[0] = IERC20(address(token0)); + tokens[1] = IERC20(address(token1)); + tokens[2] = IERC20(address(token2)); uint256[] memory bases = new uint256[](3); bases[0] = BASE; @@ -204,20 +204,20 @@ contract PartyPoolTest is Test { token2.transfer(address(pool), INIT_BAL); // Perform initial mint (initial deposit); receiver is this contract - pool.mint(address(0), address(this), 0, 0); + pool.initialMint(address(this), 0); // Set up pool10 with 10 tokens - address[] memory tokens10 = new address[](10); - tokens10[0] = address(token0); - tokens10[1] = address(token1); - tokens10[2] = address(token2); - tokens10[3] = address(token3); - tokens10[4] = address(token4); - tokens10[5] = address(token5); - tokens10[6] = address(token6); - tokens10[7] = address(token7); - tokens10[8] = address(token8); - tokens10[9] = address(token9); + IERC20[] memory tokens10 = new IERC20[](10); + tokens10[0] = IERC20(address(token0)); + tokens10[1] = IERC20(address(token1)); + tokens10[2] = IERC20(address(token2)); + tokens10[3] = IERC20(address(token3)); + tokens10[4] = IERC20(address(token4)); + tokens10[5] = IERC20(address(token5)); + tokens10[6] = IERC20(address(token6)); + tokens10[7] = IERC20(address(token7)); + tokens10[8] = IERC20(address(token8)); + tokens10[9] = IERC20(address(token9)); uint256[] memory bases10 = new uint256[](10); for (uint i = 0; i < 10; i++) { @@ -251,7 +251,7 @@ contract PartyPoolTest is Test { token9.transfer(address(pool10), INIT_BAL); // Perform initial mint for pool10 - pool10.mint(address(0), address(this), 0, 0); + pool10.initialMint(address(this), 0); // For later tests we will mint tokens to alice/bob as needed token0.mint(alice, INIT_BAL); @@ -349,7 +349,7 @@ contract PartyPoolTest is Test { token2.approve(address(pool), type(uint256).max); // Snapshot pool totals (simple value metric = sum of token uint balances since base==1 in tests) - address[] memory toks = pool.allTokens(); + IERC20[] memory toks = pool.allTokens(); uint256 n = toks.length; uint256 poolValueBefore = 0; for (uint i = 0; i < n; i++) { @@ -1210,4 +1210,171 @@ contract PartyPoolTest is Test { pool.flash(alice, wrongLengthAmounts, ""); } + /// @notice Test that passing nonzero lpTokens to initialMint doesn't affect swap results + /// compared to pools initialized with default lpTokens (0) + function testInitialMintCustomLpTokensDoesNotAffectSwaps() public { + // Create two identical pools with different initial LP amounts + IERC20[] memory tokens = new IERC20[](3); + tokens[0] = IERC20(address(token0)); + tokens[1] = IERC20(address(token1)); + tokens[2] = IERC20(address(token2)); + + uint256[] memory bases = new uint256[](3); + bases[0] = BASE; + bases[1] = BASE; + bases[2] = BASE; + + uint256 feePpm = 1000; + + // Pool with default initialization (lpTokens = 0) + PartyPool poolDefault = new PartyPool("LP_DEFAULT", "LP_DEFAULT", tokens, bases, tradeFrac, targetSlippage, feePpm, feePpm, false); + + // Pool with custom initialization (lpTokens = custom amount) + PartyPool poolCustom = new PartyPool("LP_CUSTOM", "LP_CUSTOM", tokens, bases, tradeFrac, targetSlippage, feePpm, feePpm, false); + + // Mint additional tokens for both pools + token0.mint(address(this), INIT_BAL * 2); + token1.mint(address(this), INIT_BAL * 2); + token2.mint(address(this), INIT_BAL * 2); + + // Transfer identical amounts to both pools + token0.transfer(address(poolDefault), INIT_BAL); + token1.transfer(address(poolDefault), INIT_BAL); + token2.transfer(address(poolDefault), INIT_BAL); + + token0.transfer(address(poolCustom), INIT_BAL); + token1.transfer(address(poolCustom), INIT_BAL); + token2.transfer(address(poolCustom), INIT_BAL); + + // Initialize poolDefault with lpTokens = 0 (default behavior) + uint256 lpDefault = poolDefault.initialMint(address(this), 0); + + // Initialize poolCustom with custom lpTokens amount (5x the default) + uint256 customLpAmount = lpDefault * 5; + uint256 lpCustom = poolCustom.initialMint(address(this), customLpAmount); + + // Verify the custom pool has the expected LP supply + assertEq(lpCustom, customLpAmount, "Custom pool should have expected LP amount"); + assertEq(poolCustom.totalSupply(), customLpAmount, "Custom pool total supply should match"); + + // Both pools should have identical token balances + assertEq(token0.balanceOf(address(poolDefault)), token0.balanceOf(address(poolCustom)), "Token0 balances should match"); + assertEq(token1.balanceOf(address(poolDefault)), token1.balanceOf(address(poolCustom)), "Token1 balances should match"); + assertEq(token2.balanceOf(address(poolDefault)), token2.balanceOf(address(poolCustom)), "Token2 balances should match"); + + // Prepare Alice for swapping + token0.mint(alice, INIT_BAL); + token1.mint(alice, INIT_BAL); + + // Test identical swaps produce identical results + uint256 swapAmount = 10_000; + + vm.startPrank(alice); + token0.approve(address(poolDefault), type(uint256).max); + token0.approve(address(poolCustom), type(uint256).max); + + // Perform identical swaps: token0 -> token1 + (uint256 amountInDefault, uint256 amountOutDefault, uint256 feeDefault) = poolDefault.swap(alice, alice, 0, 1, swapAmount, 0, 0); + (uint256 amountInCustom, uint256 amountOutCustom, uint256 feeCustom) = poolCustom.swap(alice, alice, 0, 1, swapAmount, 0, 0); + + // Swap results should be identical + assertEq(amountInDefault, amountInCustom, "Swap input amounts should be identical"); + assertEq(amountOutDefault, amountOutCustom, "Swap output amounts should be identical"); + assertEq(feeDefault, feeCustom, "Swap fees should be identical"); + + vm.stopPrank(); + } + + /// @notice Test that minting the same proportion in pools with different initial LP amounts + /// returns correctly scaled LP tokens + function testProportionalMintingScaledByInitialAmount() public { + // Create two identical pools with different initial LP amounts + IERC20[] memory tokens = new IERC20[](3); + tokens[0] = IERC20(address(token0)); + tokens[1] = IERC20(address(token1)); + tokens[2] = IERC20(address(token2)); + + uint256[] memory bases = new uint256[](3); + bases[0] = BASE; + bases[1] = BASE; + bases[2] = BASE; + + uint256 feePpm = 1000; + + PartyPool poolDefault = new PartyPool("LP_DEFAULT", "LP_DEFAULT", tokens, bases, tradeFrac, targetSlippage, feePpm, feePpm, false); + PartyPool poolCustom = new PartyPool("LP_CUSTOM", "LP_CUSTOM", tokens, bases, tradeFrac, targetSlippage, feePpm, feePpm, false); + + // Mint additional tokens + token0.mint(address(this), INIT_BAL * 4); + token1.mint(address(this), INIT_BAL * 4); + token2.mint(address(this), INIT_BAL * 4); + + // Transfer identical amounts to both pools + token0.transfer(address(poolDefault), INIT_BAL); + token1.transfer(address(poolDefault), INIT_BAL); + token2.transfer(address(poolDefault), INIT_BAL); + + token0.transfer(address(poolCustom), INIT_BAL); + token1.transfer(address(poolCustom), INIT_BAL); + token2.transfer(address(poolCustom), INIT_BAL); + + // Initialize pools with different LP amounts + uint256 lpDefault = poolDefault.initialMint(address(this), 0); + uint256 scaleFactor = 3; + uint256 customLpAmount = lpDefault * scaleFactor; + uint256 lpCustom = poolCustom.initialMint(address(this), customLpAmount); + + // Verify initial LP supplies + assertEq(poolDefault.totalSupply(), lpDefault, "Default pool should have default LP supply"); + assertEq(poolCustom.totalSupply(), customLpAmount, "Custom pool should have custom LP supply"); + + // Prepare Alice for minting + token0.mint(alice, INIT_BAL * 2); + token1.mint(alice, INIT_BAL * 2); + token2.mint(alice, INIT_BAL * 2); + + // Test proportional minting: mint 10% of each pool's supply + uint256 mintPercentage = 10; // 10% + uint256 lpRequestDefault = poolDefault.totalSupply() * mintPercentage / 100; + uint256 lpRequestCustom = poolCustom.totalSupply() * mintPercentage / 100; + + vm.startPrank(alice); + + // Approve tokens for both pools + token0.approve(address(poolDefault), type(uint256).max); + token1.approve(address(poolDefault), type(uint256).max); + token2.approve(address(poolDefault), type(uint256).max); + token0.approve(address(poolCustom), type(uint256).max); + token1.approve(address(poolCustom), type(uint256).max); + token2.approve(address(poolCustom), type(uint256).max); + + // Get required deposit amounts for both pools + uint256[] memory depositsDefault = poolDefault.mintDepositAmounts(lpRequestDefault); + uint256[] memory depositsCustom = poolCustom.mintDepositAmounts(lpRequestCustom); + + // Deposits should be identical (same proportion of identical balances) + assertEq(depositsDefault[0], depositsCustom[0], "Token0 deposits should be identical"); + assertEq(depositsDefault[1], depositsCustom[1], "Token1 deposits should be identical"); + assertEq(depositsDefault[2], depositsCustom[2], "Token2 deposits should be identical"); + + // Perform the mints + uint256 mintedDefault = poolDefault.mint(alice, alice, lpRequestDefault, 0); + uint256 mintedCustom = poolCustom.mint(alice, alice, lpRequestCustom, 0); + + // Minted LP amounts should be scaled by the same factor as initial supplies + uint256 expectedRatio = (mintedCustom * 1000) / mintedDefault; // Use fixed point for precision + uint256 actualRatio = (scaleFactor * 1000); + + // Allow small rounding differences (within 0.1%) + uint256 tolerance = actualRatio / 1000; // 0.1% tolerance + assertTrue(expectedRatio >= actualRatio - tolerance && expectedRatio <= actualRatio + tolerance, + "Minted LP ratio should match scale factor within tolerance"); + + // Verify Alice received the expected LP amounts + assertTrue(poolDefault.balanceOf(alice) >= mintedDefault, "Alice should receive default LP"); + assertTrue(poolCustom.balanceOf(alice) >= mintedCustom, "Alice should receive custom LP"); + + vm.stopPrank(); + } + }