diff --git a/src/IPartyPool.sol b/src/IPartyPool.sol index 09647a7..4e8418c 100644 --- a/src/IPartyPool.sol +++ b/src/IPartyPool.sol @@ -87,7 +87,7 @@ interface IPartyPool is IERC20Metadata { /// because the initial deposit is handled by transferring tokens then calling mint(). /// @param lpTokenAmount The amount of LP tokens desired /// @return depositAmounts Array of token amounts to deposit (rounded up) - function mintDepositAmounts(uint256 lpTokenAmount) external view returns (uint256[] memory depositAmounts); + function mintAmounts(uint256 lpTokenAmount) external view returns (uint256[] memory depositAmounts); /// @notice Proportional mint (or initial supply if first call). /// @dev - For initial supply: assumes tokens have already been transferred to the pool prior to calling. @@ -105,7 +105,7 @@ interface IPartyPool is IERC20Metadata { /// If the pool is uninitialized or supply is zero, returns zeros. /// @param lpTokenAmount The amount of LP tokens to burn /// @return withdrawAmounts Array of token amounts to withdraw (rounded down) - function burnReceiveAmounts(uint256 lpTokenAmount) external view returns (uint256[] memory withdrawAmounts); + function burnAmounts(uint256 lpTokenAmount) external view returns (uint256[] memory withdrawAmounts); /// @notice Burn LP tokens and withdraw the proportional basket to receiver. /// @dev Payer must own or approve the LP tokens being burned. The function updates LMSR state diff --git a/src/PartyPool.sol b/src/PartyPool.sol index 2c74100..7f44752 100644 --- a/src/PartyPool.sol +++ b/src/PartyPool.sol @@ -126,11 +126,11 @@ contract PartyPool is PartyPoolBase, IPartyPool { ---------------------- */ /// @inheritdoc IPartyPool - function mintDepositAmounts(uint256 lpTokenAmount) public view returns (uint256[] memory depositAmounts) { - return _mintDepositAmounts(lpTokenAmount); + function mintAmounts(uint256 lpTokenAmount) public view returns (uint256[] memory depositAmounts) { + return _mintAmounts(lpTokenAmount); } - function _mintDepositAmounts(uint256 lpTokenAmount) internal view returns (uint256[] memory depositAmounts) { + function _mintAmounts(uint256 lpTokenAmount) internal view returns (uint256[] memory depositAmounts) { uint256 n = tokens.length; depositAmounts = new uint256[](n); @@ -216,11 +216,11 @@ contract PartyPool is PartyPoolBase, IPartyPool { } /// @inheritdoc IPartyPool - function burnReceiveAmounts(uint256 lpTokenAmount) external view returns (uint256[] memory withdrawAmounts) { - return _burnReceiveAmounts(lpTokenAmount); + function burnAmounts(uint256 lpTokenAmount) external view returns (uint256[] memory withdrawAmounts) { + return _burnAmounts(lpTokenAmount); } - function _burnReceiveAmounts(uint256 lpTokenAmount) internal view returns (uint256[] memory withdrawAmounts) { + function _burnAmounts(uint256 lpTokenAmount) internal view returns (uint256[] memory withdrawAmounts) { uint256 n = tokens.length; withdrawAmounts = new uint256[](n); diff --git a/src/PartyPoolMintImpl.sol b/src/PartyPoolMintImpl.sol index cef3a62..3aa7ddc 100644 --- a/src/PartyPoolMintImpl.sol +++ b/src/PartyPoolMintImpl.sol @@ -43,7 +43,7 @@ contract PartyPoolMintImpl is PartyPoolBase { uint256 oldScaled = ABDKMath64x64.mulu(oldTotal, LP_SCALE); // Calculate required deposit amounts for the desired LP tokens - uint256[] memory depositAmounts = _mintDepositAmounts(lpTokenAmount); + uint256[] memory depositAmounts = _mintAmounts(lpTokenAmount, lmsr.nAssets, totalSupply()); // Transfer in all token amounts for (uint i = 0; i < n; ) { @@ -119,7 +119,7 @@ contract PartyPoolMintImpl is PartyPoolBase { } // Compute proportional withdrawal amounts for the requested LP amount (rounded down) - uint256[] memory withdrawAmounts = _burnReceiveAmounts(lpAmount); + uint256[] memory withdrawAmounts = _burnAmounts(lpAmount); // Transfer underlying tokens out to receiver according to computed proportions for (uint i = 0; i < n; ) { @@ -165,34 +165,33 @@ contract PartyPoolMintImpl is PartyPoolBase { emit Burn(payer, receiver, withdrawAmounts, lpAmount); } - /// @notice Internal helper to calculate required deposit amounts for minting LP tokens - function _mintDepositAmounts(uint256 lpTokenAmount) internal view returns (uint256[] memory depositAmounts) { - uint256 n = tokens.length; - depositAmounts = new uint256[](n); + function mintAmounts(uint256 lpTokenAmount, uint256 numAssets, uint256 totalSupply) public view returns (uint256[] memory depositAmounts) { + return _mintAmounts(lpTokenAmount, numAssets, totalSupply); + } + + function _mintAmounts(uint256 lpTokenAmount, uint256 numAssets, uint256 totalSupply) internal view returns (uint256[] memory depositAmounts) { + depositAmounts = new uint256[](numAssets); // If this is the first mint or pool is empty, return zeros // For first mint, tokens should already be transferred to the pool - if (totalSupply() == 0 || lmsr.nAssets == 0) { + if (totalSupply == 0 || numAssets == 0) { return depositAmounts; // Return zeros, initial deposit handled differently } - // Calculate deposit based on current proportions - uint256 totalLpSupply = totalSupply(); - // lpTokenAmount / totalLpSupply = depositAmount / currentBalance // Therefore: depositAmount = (lpTokenAmount * currentBalance) / totalLpSupply // We round up to protect the pool - for (uint i = 0; i < n; i++) { + for (uint i = 0; i < numAssets; i++) { uint256 currentBalance = cachedUintBalances[i]; // Calculate with rounding up: (a * b + c - 1) / c - depositAmounts[i] = (lpTokenAmount * currentBalance + totalLpSupply - 1) / totalLpSupply; + depositAmounts[i] = (lpTokenAmount * currentBalance + totalSupply - 1) / totalSupply; } return depositAmounts; } /// @notice Internal helper to calculate withdrawal amounts for burning LP tokens - function _burnReceiveAmounts(uint256 lpTokenAmount) internal view returns (uint256[] memory withdrawAmounts) { + function _burnAmounts(uint256 lpTokenAmount) internal view returns (uint256[] memory withdrawAmounts) { uint256 n = tokens.length; withdrawAmounts = new uint256[](n); diff --git a/test/PartyPool.t.sol b/test/PartyPool.t.sol index 0f9f484..7036816 100644 --- a/test/PartyPool.t.sol +++ b/test/PartyPool.t.sol @@ -324,7 +324,7 @@ contract PartyPoolTest is Test { token2.approve(address(pool), type(uint256).max); // Inspect the deposit amounts that the pool will require (these are rounded up) - uint256[] memory deposits = pool.mintDepositAmounts(1); + uint256[] memory deposits = pool.mintAmounts(1); // Basic sanity: deposits array length must match token count and not all zero necessarily assertEq(deposits.length, 3); @@ -366,7 +366,7 @@ contract PartyPoolTest is Test { uint256 totalLpBefore = pool.totalSupply(); // Compute required deposits and perform mint for 1 wei - uint256[] memory deposits = pool.mintDepositAmounts(1); + uint256[] memory deposits = pool.mintAmounts(1); // Sum deposits as deposited_value uint256 depositedValue = 0; @@ -400,14 +400,14 @@ contract PartyPoolTest is Test { vm.stopPrank(); } - /// @notice mintDepositAmounts should round up deposit amounts to protect the pool. + /// @notice mintAmounts should round up deposit amounts to protect the pool. function testMintDepositAmountsRoundingUp() public view { uint256 totalLp = pool.totalSupply(); assertTrue(totalLp > 0, "precondition: total supply > 0"); // Request half of LP supply uint256 want = totalLp / 2; - uint256[] memory deposits = pool.mintDepositAmounts(want); + uint256[] memory deposits = pool.mintAmounts(want); // We expect each deposit to be roughly half the pool balance, but due to rounding up it should satisfy: // deposits[i] * 2 >= cached balance (i.e., rounding up) @@ -424,7 +424,7 @@ contract PartyPoolTest is Test { assertTrue(totalLp > 0, "precondition: LP > 0"); // Compute amounts required to redeem entire supply (should be current balances) - uint256[] memory withdrawAmounts = pool.burnReceiveAmounts(totalLp); + uint256[] memory withdrawAmounts = pool.burnAmounts(totalLp); // Sanity: withdrawAmounts should equal pool balances (or very close due to rounding) for (uint i = 0; i < withdrawAmounts.length; i++) { @@ -514,7 +514,7 @@ contract PartyPoolTest is Test { } - /// @notice Verify mintDepositAmounts matches the actual token transfers performed by mint() + /// @notice Verify mintAmounts matches the actual token transfers performed by mint() function testMintDepositAmountsMatchesMint_3TokenPool() public { // Use a range of LP requests (tiny to large fraction) uint256 totalLp = pool.totalSupply(); @@ -528,7 +528,7 @@ contract PartyPoolTest is Test { if (req == 0) req = 1; // Compute expected deposit amounts via view - uint256[] memory expected = pool.mintDepositAmounts(req); + uint256[] memory expected = pool.mintAmounts(req); // Ensure alice has tokens and approve pool vm.startPrank(alice); @@ -542,7 +542,7 @@ contract PartyPoolTest is Test { uint256 a2Before = token2.balanceOf(alice); // Perform mint (may revert for zero-request; ensure req>0 above) - // Guard: if mintDepositAmounts returned all zeros, skip (nothing to transfer) + // Guard: if mintAmounts returned all zeros, skip (nothing to transfer) bool allZero = (expected[0] == 0 && expected[1] == 0 && expected[2] == 0); if (!allZero) { uint256 lpBefore = pool.balanceOf(alice); @@ -561,7 +561,7 @@ contract PartyPoolTest is Test { } } - /// @notice Verify mintDepositAmounts matches the actual token transfers performed by mint() for 10-token pool + /// @notice Verify mintAmounts matches the actual token transfers performed by mint() for 10-token pool function testMintDepositAmountsMatchesMint_10TokenPool() public { uint256 totalLp = pool10.totalSupply(); uint256[] memory requests = new uint256[](4); @@ -573,7 +573,7 @@ contract PartyPoolTest is Test { uint256 req = requests[k]; if (req == 0) req = 1; - uint256[] memory expected = pool10.mintDepositAmounts(req); + uint256[] memory expected = pool10.mintAmounts(req); // Approve all tokens from alice vm.startPrank(alice); @@ -624,7 +624,7 @@ contract PartyPoolTest is Test { } } - /// @notice Verify burnReceiveAmounts matches actual transfers performed by burn() for 3-token pool + /// @notice Verify burnAmounts matches actual transfers performed by burn() for 3-token pool function testBurnReceiveAmountsMatchesBurn_3TokenPool() public { // Use address(this) as payer (holds initial LP from setUp) uint256 totalLp = pool.totalSupply(); @@ -651,7 +651,7 @@ contract PartyPoolTest is Test { } // Recompute withdraw amounts via view after any top-up - uint256[] memory expected = pool.burnReceiveAmounts(req); + uint256[] memory expected = pool.burnAmounts(req); // If expected withdraws are all zero (rounding edge), skip this iteration if (expected[0] == 0 && expected[1] == 0 && expected[2] == 0) { @@ -677,7 +677,7 @@ contract PartyPoolTest is Test { } } - /// @notice Verify burnReceiveAmounts matches actual transfers performed by burn() for 10-token pool + /// @notice Verify burnAmounts matches actual transfers performed by burn() for 10-token pool function testBurnReceiveAmountsMatchesBurn_10TokenPool() public { uint256 totalLp = pool10.totalSupply(); uint256[] memory burns = new uint256[](4); @@ -708,7 +708,7 @@ contract PartyPoolTest is Test { vm.stopPrank(); } - uint256[] memory expected = pool10.burnReceiveAmounts(req); + uint256[] memory expected = pool10.burnAmounts(req); // If expected withdraws are all zero (rounding edge), skip this iteration bool allZero = true; @@ -1361,8 +1361,8 @@ contract PartyPoolTest is Test { token2.approve(address(poolCustom), type(uint256).max); // Get required deposit amounts for both pools - uint256[] memory depositsDefault = poolDefault.mintDepositAmounts(lpRequestDefault); - uint256[] memory depositsCustom = poolCustom.mintDepositAmounts(lpRequestCustom); + uint256[] memory depositsDefault = poolDefault.mintAmounts(lpRequestDefault); + uint256[] memory depositsCustom = poolCustom.mintAmounts(lpRequestCustom); // Deposits should be identical (same proportion of identical balances) assertEq(depositsDefault[0], depositsCustom[0], "Token0 deposits should be identical");