diff --git a/src/IPartyPool.sol b/src/IPartyPool.sol index c649c09..558300a 100644 --- a/src/IPartyPool.sol +++ b/src/IPartyPool.sol @@ -1,6 +1,7 @@ // SPDX-License-Identifier: UNLICENSED pragma solidity ^0.8.30; +import "../lib/openzeppelin-contracts/contracts/interfaces/IERC3156FlashBorrower.sol"; import "./LMSRStabilized.sol"; import {IERC20Metadata} from "../lib/openzeppelin-contracts/contracts/token/ERC20/extensions/IERC20Metadata.sol"; import {IERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/IERC20.sol"; @@ -135,12 +136,14 @@ interface IPartyPool is IERC20Metadata { /// @param maxAmountIn maximum gross input allowed (inclusive of fee) /// @param limitPrice maximum acceptable marginal price (pass 0 to ignore) /// @return amountIn gross input amount to transfer (includes fee), amountOut output amount user would receive, fee fee amount taken +/* function swapAmounts( uint256 inputTokenIndex, uint256 outputTokenIndex, uint256 maxAmountIn, int128 limitPrice ) external view returns (uint256 amountIn, uint256 amountOut, uint256 fee); +*/ /// @notice Swap input token inputTokenIndex -> token outputTokenIndex. Payer must approve token inputTokenIndex. /// @dev This function transfers the exact gross input (including fee) from payer and sends the computed output to receiver. @@ -215,17 +218,17 @@ interface IPartyPool is IERC20Metadata { uint256 deadline ) external returns (uint256 amountOutUint); - /// @notice Receive token amounts and require them to be repaid plus a fee inside a callback. - /// @dev The caller must implement IPartyFlashCallback#partyFlashCallback which receives (amounts, repaymentAmounts, data). - /// This function verifies that, after the callback returns, the pool's balances have increased by at least the fees - /// for each borrowed token. Reverts if repayment (including fee) did not occur. - /// @param recipient The address which will receive the token amounts - /// @param amounts The amount of each token to send (array length must equal pool size) - /// @param data Any data to be passed through to the callback - function flash( - address recipient, - uint256[] memory amounts, + /** + * @dev Initiate a flash loan. + * @param receiver The receiver of the tokens in the loan, and the receiver of the callback. + * @param token The loan currency. + * @param amount The amount of tokens lent. + * @param data Arbitrary data structure, intended to contain user-defined parameters. + */ + function flashLoan( + IERC3156FlashBorrower receiver, + address token, + uint256 amount, bytes calldata data - ) external; - + ) external returns (bool); } diff --git a/src/PartyPool.sol b/src/PartyPool.sol index e7b9e79..34b79cb 100644 --- a/src/PartyPool.sol +++ b/src/PartyPool.sol @@ -1,20 +1,22 @@ // SPDX-License-Identifier: UNLICENSED pragma solidity ^0.8.30; +import "../lib/openzeppelin-contracts/contracts/interfaces/IERC3156FlashBorrower.sol"; import {ABDKMath64x64} from "../lib/abdk-libraries-solidity/ABDKMath64x64.sol"; -import {IERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/IERC20.sol"; -import {SafeERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/utils/SafeERC20.sol"; import {Address} from "../lib/openzeppelin-contracts/contracts/utils/Address.sol"; -import {ReentrancyGuard} from "../lib/openzeppelin-contracts/contracts/utils/ReentrancyGuard.sol"; import {ERC20External} from "./ERC20External.sol"; +import {IERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/IERC20.sol"; import {IPartyFlashCallback} from "./IPartyFlashCallback.sol"; import {IPartyPool} from "./IPartyPool.sol"; -import {LMSRStabilized} from "./LMSRStabilized.sol"; import {LMSRStabilizedBalancedPair} from "./LMSRStabilizedBalancedPair.sol"; +import {LMSRStabilized} from "./LMSRStabilized.sol"; import {PartyPoolBase} from "./PartyPoolBase.sol"; import {PartyPoolMintImpl} from "./PartyPoolMintImpl.sol"; import {PartyPoolSwapImpl} from "./PartyPoolSwapImpl.sol"; import {Proxy} from "../lib/openzeppelin-contracts/contracts/proxy/Proxy.sol"; +import {ReentrancyGuard} from "../lib/openzeppelin-contracts/contracts/utils/ReentrancyGuard.sol"; +import {SafeERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/utils/SafeERC20.sol"; +import {IERC3156FlashLender} from "../lib/openzeppelin-contracts/contracts/interfaces/IERC3156FlashLender.sol"; /// @title PartyPool - LMSR-backed multi-asset pool with LP ERC20 token /// @notice A multi-asset liquidity pool backed by the LMSRStabilized pricing model. @@ -196,7 +198,7 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { Swaps ---------------------- */ - /// @inheritdoc IPartyPool +/* function swapAmounts( uint256 inputTokenIndex, uint256 outputTokenIndex, @@ -206,6 +208,7 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { (uint256 grossIn, uint256 outUint,,,, uint256 feeUint) = _quoteSwapExactIn(inputTokenIndex, outputTokenIndex, maxAmountIn, limitPrice); return (grossIn, outUint, feeUint); } +*/ /// @inheritdoc IPartyPool function swap( @@ -295,7 +298,6 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { require(deltaInternalI > int128(0), "swap: input too small after fee"); // Compute internal amounts using LMSR (exact-input with price limit) - // if _stablePair is true, use the optimized path (amountInInternalUsed, amountOutInternal) = _swapAmountsForExactInput(inputTokenIndex, outputTokenIndex, deltaInternalI, limitPrice); // Convert actual used input internal -> uint (ceil) @@ -405,82 +407,38 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { } - /// @notice Receive token amounts and require them to be repaid plus a fee inside a callback. - /// @dev The caller must implement IPartyFlashCallback#partyFlashCallback which receives (amounts, repaymentAmounts, data). - /// This function verifies that, after the callback returns, the pool's balances have increased by at least the fees - /// for each borrowed token. Reverts if repayment (including fee) did not occur. - /// @param recipient The address which will receive the token amounts - /// @param amounts The amount of each token to send (array length must equal pool size) - /// @param data Any data to be passed through to the callback - // todo gas-efficient single-asset flash - // todo fix this func's gas - function flash( - address recipient, - uint256[] memory amounts, + bytes32 internal constant FLASH_CALLBACK_SUCCESS = keccak256("ERC3156FlashBorrower.onFlashLoan"); + + /** + * @dev Loan `amount` tokens to `receiver`, and takes it back plus a `flashFee` after the callback. + * @param receiver The contract receiving the tokens, needs to implement the `onFlashLoan(address user, uint256 amount, uint256 fee, bytes calldata)` interface. + * @param tokenAddr The loan currency. + * @param amount The amount of tokens lent. + * @param data A data parameter to be passed on to the `receiver` for any custom use. + */ + function flashLoan( + IERC3156FlashBorrower receiver, + address tokenAddr, + uint256 amount, bytes calldata data - ) external nonReentrant { - require(recipient != address(0), "flash: zero recipient"); - require(amounts.length == tokens.length, "flash: amounts length mismatch"); - - // Calculate repayment amounts for each token including fee - uint256[] memory repaymentAmounts = new uint256[](tokens.length); - - // Store initial balances to verify repayment later - uint256[] memory initialBalances = new uint256[](tokens.length); - - // Track if any token amount is non-zero - bool hasNonZeroAmount = false; - - // Process each token, skipping those with zero amounts - for (uint256 i = 0; i < tokens.length; i++) { - uint256 amount = amounts[i]; - - if (amount > 0) { - hasNonZeroAmount = true; - - // Calculate repayment amount with fee (ceiling) - repaymentAmounts[i] = amount + _ceilFee(amount, FLASH_FEE_PPM); - - // Record initial balance - initialBalances[i] = IERC20(tokens[i]).balanceOf(address(this)); - - // Transfer token to recipient - tokens[i].safeTransfer(recipient, amount); - } - } - - // Ensure at least one token is being borrowed - require(hasNonZeroAmount, "flash: no tokens requested"); - - // Call flash callback with expected repayment amounts - IPartyFlashCallback(msg.sender).partyFlashCallback(amounts, repaymentAmounts, data); - - // Verify repayment amounts for tokens that were borrowed - for (uint256 i = 0; i < tokens.length; i++) { - if (amounts[i] > 0) { - uint256 currentBalance = IERC20(tokens[i]).balanceOf(address(this)); - - // Compute expected fee (ceiling) - uint256 feeExpected = _ceilFee(amounts[i], FLASH_FEE_PPM); - - // Verify repayment: current balance must be at least (initial balance + fee) - require( - currentBalance >= initialBalances[i] + feeExpected, - "flash: repayment failed" - ); - - // Accrue protocol share (floor) of the flash fee - if (PROTOCOL_FEE_PPM > 0 && PROTOCOL_FEE_ADDRESS != address(0)) { - uint256 protoShare = (feeExpected * PROTOCOL_FEE_PPM) / 1_000_000; // floor - if (protoShare > 0) { - protocolFeesOwed[i] += protoShare; - } - } - - // Update cached balance to onchain minus owed - _recordCachedBalance(i, currentBalance); - } - } + ) external nonReentrant returns (bool) + { + IERC20 token = IERC20(tokenAddr); + require(amount <= token.balanceOf(address(this))); + (uint256 fee, ) = _computeFee(amount, FLASH_FEE_PPM); + require( + token.transfer(address(receiver), amount), + "FlashLender: Transfer failed" + ); + require( + receiver.onFlashLoan(msg.sender, address(token), amount, fee, data) == FLASH_CALLBACK_SUCCESS, + "FlashLender: Callback failed" + ); + require( + token.transferFrom(address(receiver), address(this), amount + fee), + "FlashLender: Repay failed" + ); + return true; } diff --git a/src/PartyPoolView.sol b/src/PartyPoolView.sol index d56da86..3d6998a 100644 --- a/src/PartyPoolView.sol +++ b/src/PartyPoolView.sol @@ -2,6 +2,7 @@ pragma solidity ^0.8.30; import {ABDKMath64x64} from "../lib/abdk-libraries-solidity/ABDKMath64x64.sol"; +import {IERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/IERC20.sol"; import {IPartyPool} from "./IPartyPool.sol"; import {LMSRStabilized} from "./LMSRStabilized.sol"; import {PartyPoolHelpers} from "./PartyPoolHelpers.sol"; @@ -151,4 +152,31 @@ contract PartyPoolView is PartyPoolHelpers { } + /** + * @dev The amount of currency available to be lent. + * @param token The loan currency. + * @return The amount of `token` that can be borrowed. + */ + function maxFlashLoan( + IPartyPool pool, + address token + ) external view returns (uint256) { + return IERC20(token).balanceOf(address(pool)); + } + + /** + * @dev The fee to be charged for a given loan. + * @param token The loan currency. + * @param amount The amount of tokens lent. + * @return fee The amount of `token` to be charged for the loan, on top of the returned principal. + */ + function flashFee( + IPartyPool pool, + address token, + uint256 amount + ) external view returns (uint256 fee) { + (fee,) = _computeFee(amount, pool.flashFeePpm()); + } + + } diff --git a/test/GasTest.sol b/test/GasTest.sol index c47fcb0..0a981eb 100644 --- a/test/GasTest.sol +++ b/test/GasTest.sol @@ -8,104 +8,74 @@ import "@openzeppelin/contracts/token/ERC20/ERC20.sol"; import "../src/LMSRStabilized.sol"; import "../src/PartyPool.sol"; import "../src/PartyPlanner.sol"; -import "../src/IPartyFlashCallback.sol"; +import "../lib/openzeppelin-contracts/contracts/interfaces/IERC3156FlashBorrower.sol"; import {Deploy} from "../src/Deploy.sol"; /// @notice Test contract that implements the flash callback for testing flash loans -contract FlashBorrower is IPartyFlashCallback { +contract FlashBorrower is IERC3156FlashBorrower { enum Action { NORMAL, // Normal repayment REPAY_NONE, // Don't repay anything REPAY_PARTIAL, // Repay less than required REPAY_NO_FEE, // Repay only the principal without fee - REPAY_EXACT, // Repay exactly the required amount - REPAY_EXTRA // Repay more than required (donation) + REPAY_EXACT // Repay exactly the required amount } Action public action; address public pool; - address public recipient; - address[] public tokens; + address public payer; - constructor(address _pool, IERC20[] memory _tokens) { + constructor(address _pool) { pool = _pool; - tokens = new address[](_tokens.length); - for (uint i = 0; i < _tokens.length; i++) { - tokens[i] = address(_tokens[i]); - } } - function setAction(Action _action, address _recipient) external { + function setAction(Action _action, address _payer) external { action = _action; - recipient = _recipient; + payer = _payer; } - function flash(uint256[] memory amounts) external { - PartyPool(pool).flash(recipient, amounts, ""); + function flash(address token, uint256 amount) external { + PartyPool(pool).flashLoan(IERC3156FlashBorrower(address(this)), token, amount, ""); } - function partyFlashCallback( - uint256[] memory loanAmounts, - uint256[] memory repaymentAmounts, + function onFlashLoan( + address initiator, + address token, + uint256 amount, + uint256 fee, bytes calldata /* data */ - ) external override { + ) external override returns (bytes32) { require(msg.sender == pool, "Callback not called by pool"); - if (action == Action.NORMAL || action == Action.REPAY_EXTRA) { - // Normal or extra repayment - transfer required amounts back to pool - for (uint256 i = 0; i < loanAmounts.length; i++) { - if (loanAmounts[i] > 0) { - uint256 repaymentAmount = repaymentAmounts[i]; + if (action == Action.NORMAL) { + // Normal repayment + // We received 'amount' from the pool, need to pay back amount + fee + uint256 repaymentAmount = amount + fee; - // For REPAY_EXTRA, add 1 to each repayment - if (action == Action.REPAY_EXTRA) { - repaymentAmount += 1; - } + // Transfer the fee from payer to this contract + // (we already have the principal 'amount' from the flash loan) + TestERC20(token).transferFrom(payer, address(this), fee); - // Transfer from recipient back to pool - TestERC20(tokens[i]).transferFrom( - recipient, - pool, - repaymentAmount - ); - } - } + // Approve pool to pull back the full repayment + TestERC20(token).approve(pool, repaymentAmount); } else if (action == Action.REPAY_PARTIAL) { - // Repay half of the required amounts - for (uint256 i = 0; i < loanAmounts.length; i++) { - if (loanAmounts[i] > 0) { - uint256 partialRepayment = repaymentAmounts[i] / 2; - TestERC20(tokens[i]).transferFrom( - recipient, - pool, - partialRepayment - ); - } - } + // Repay half of the required amount + uint256 partialRepayment = (amount + fee) / 2; + TestERC20(token).approve(pool, partialRepayment); } else if (action == Action.REPAY_NO_FEE) { - // Repay only the principal without fee - for (uint256 i = 0; i < loanAmounts.length; i++) { - if (loanAmounts[i] > 0) { - TestERC20(tokens[i]).transferFrom( - recipient, - pool, - loanAmounts[i] - ); - } - } + // Repay only the principal without fee (we already have it from the loan) + TestERC20(token).approve(pool, amount); } else if (action == Action.REPAY_EXACT) { // Repay exactly what was required - for (uint256 i = 0; i < loanAmounts.length; i++) { - if (loanAmounts[i] > 0) { - TestERC20(tokens[i]).transferFrom( - recipient, - pool, - repaymentAmounts[i] - ); - } - } + uint256 repaymentAmount = amount + fee; + // Transfer the fee from payer (we have the principal from the loan) + TestERC20(token).transferFrom(payer, address(this), fee); + // Approve pool to pull back the full repayment + TestERC20(token).approve(pool, repaymentAmount); } - // For REPAY_NONE, do nothing (don't repay) + // For REPAY_NONE, do nothing (don't approve repayment) + + return keccak256("ERC3156FlashBorrower.onFlashLoan"); } } @@ -247,13 +217,11 @@ contract GasTest is Test { /// @notice Setup a flash borrower for testing function setupFlashBorrower() internal returns (FlashBorrower borrower) { - // Get token addresses from the 2-token pool - IERC20[] memory tokenAddresses = pool2.allTokens(); - // Deploy the borrower contract - borrower = new FlashBorrower(address(pool2), tokenAddresses); + borrower = new FlashBorrower(address(pool2)); // Mint tokens to alice to be used for repayments and approve borrower + IERC20[] memory tokenAddresses = pool2.allTokens(); vm.startPrank(alice); for (uint256 i = 0; i < tokenAddresses.length; i++) { TestERC20(address(tokenAddresses[i])).mint(alice, INIT_BAL * 2); @@ -441,33 +409,14 @@ contract GasTest is Test { // Configure borrower borrower.setAction(FlashBorrower.Action.NORMAL, alice); - // Create loan request for single token (get array size from pool) + // Get first token from pool IERC20[] memory poolTokens = pool2.allTokens(); - uint256[] memory amounts = new uint256[](poolTokens.length); - amounts[0] = 1000; + address token = address(poolTokens[0]); + uint256 amount = 1000; // Execute flash loan 10 times to measure gas for (uint256 i = 0; i < 10; i++) { - borrower.flash(amounts); - } - } - - /// @notice Gas measurement: flash with multiple tokens - function testFlashGasMultipleTokens() public { - FlashBorrower borrower = setupFlashBorrower(); - - // Configure borrower - borrower.setAction(FlashBorrower.Action.NORMAL, alice); - - // Create loan request for multiple tokens (get array size from pool) - IERC20[] memory poolTokens = pool2.allTokens(); - uint256[] memory amounts = new uint256[](poolTokens.length); - amounts[0] = 1000; - amounts[1] = 2000; - - // Execute flash loan 10 times to measure gas - for (uint256 i = 0; i < 10; i++) { - borrower.flash(amounts); + borrower.flash(token, amount); } } } diff --git a/test/PartyPool.t.sol b/test/PartyPool.t.sol index a8da3d5..8aa08e4 100644 --- a/test/PartyPool.t.sol +++ b/test/PartyPool.t.sol @@ -9,103 +9,76 @@ import "../src/LMSRStabilized.sol"; import "../src/PartyPool.sol"; // Import the flash callback interface -import "../src/IPartyFlashCallback.sol"; +import "../lib/openzeppelin-contracts/contracts/interfaces/IERC3156FlashBorrower.sol"; import {PartyPlanner} from "../src/PartyPlanner.sol"; import {Deploy} from "../src/Deploy.sol"; import {PartyPoolView} from "../src/PartyPoolView.sol"; /// @notice Test contract that implements the flash callback for testing flash loans -contract FlashBorrower is IPartyFlashCallback { +contract FlashBorrower is IERC3156FlashBorrower { enum Action { NORMAL, // Normal repayment REPAY_NONE, // Don't repay anything REPAY_PARTIAL, // Repay less than required REPAY_NO_FEE, // Repay only the principal without fee - REPAY_EXACT, // Repay exactly the required amount - REPAY_EXTRA // Repay more than required (donation) + REPAY_EXACT // Repay exactly the required amount } Action public action; address public pool; - address public recipient; - address[] public tokens; + address public payer; - constructor(address _pool, address[] memory _tokens) { + constructor(address _pool) { pool = _pool; - tokens = _tokens; } - function setAction(Action _action, address _recipient) external { + function setAction(Action _action, address _payer) external { action = _action; - recipient = _recipient; + payer = _payer; } - function flash(uint256[] memory amounts) external { - PartyPool(pool).flash(recipient, amounts, ""); + function flash(address token, uint256 amount) external { + PartyPool(pool).flashLoan(IERC3156FlashBorrower(address(this)), token, amount, ""); } - function partyFlashCallback( - uint256[] memory loanAmounts, - uint256[] memory repaymentAmounts, + function onFlashLoan( + address initiator, + address token, + uint256 amount, + uint256 fee, bytes calldata /* data */ - ) external override { + ) external override returns (bytes32) { require(msg.sender == pool, "Callback not called by pool"); - if (action == Action.NORMAL || action == Action.REPAY_EXTRA) { - // Normal or extra repayment - transfer required amounts back to pool - for (uint256 i = 0; i < loanAmounts.length; i++) { - if (loanAmounts[i] > 0) { - uint256 repaymentAmount = repaymentAmounts[i]; + if (action == Action.NORMAL) { + // Normal repayment + // We received 'amount' from the pool, need to pay back amount + fee + uint256 repaymentAmount = amount + fee; - // For REPAY_EXTRA, add 1 to each repayment - if (action == Action.REPAY_EXTRA) { - repaymentAmount += 1; - } + // Transfer the fee from payer to this contract + // (we already have the principal 'amount' from the flash loan) + TestERC20(token).transferFrom(payer, address(this), fee); - // Transfer from recipient back to pool - TestERC20(tokens[i]).transferFrom( - recipient, - pool, - repaymentAmount - ); - } - } + // Approve pool to pull back the full repayment + TestERC20(token).approve(pool, repaymentAmount); } else if (action == Action.REPAY_PARTIAL) { - // Repay half of the required amounts - for (uint256 i = 0; i < loanAmounts.length; i++) { - if (loanAmounts[i] > 0) { - uint256 partialRepayment = repaymentAmounts[i] / 2; - TestERC20(tokens[i]).transferFrom( - recipient, - pool, - partialRepayment - ); - } - } + // Repay half of the required amount + uint256 partialRepayment = (amount + fee) / 2; + TestERC20(token).approve(pool, partialRepayment); } else if (action == Action.REPAY_NO_FEE) { - // Repay only the principal without fee - for (uint256 i = 0; i < loanAmounts.length; i++) { - if (loanAmounts[i] > 0) { - TestERC20(tokens[i]).transferFrom( - recipient, - pool, - loanAmounts[i] - ); - } - } + // Repay only the principal without fee (we already have it from the loan) + TestERC20(token).approve(pool, amount); } else if (action == Action.REPAY_EXACT) { // Repay exactly what was required - for (uint256 i = 0; i < loanAmounts.length; i++) { - if (loanAmounts[i] > 0) { - TestERC20(tokens[i]).transferFrom( - recipient, - pool, - repaymentAmounts[i] - ); - } - } + uint256 repaymentAmount = amount + fee; + // Transfer the fee from payer (we have the principal from the loan) + TestERC20(token).transferFrom(payer, address(this), fee); + // Approve pool to pull back the full repayment + TestERC20(token).approve(pool, repaymentAmount); } - // For REPAY_NONE, do nothing (don't repay) + // For REPAY_NONE, do nothing (don't approve repayment) + + return keccak256("ERC3156FlashBorrower.onFlashLoan"); } } @@ -846,14 +819,8 @@ contract PartyPoolTest is Test { /// @notice Setup a flash borrower for testing function setupFlashBorrower() internal returns (FlashBorrower borrower) { - // Create array of token addresses for borrower - address[] memory tokenAddresses = new address[](3); - tokenAddresses[0] = address(token0); - tokenAddresses[1] = address(token1); - tokenAddresses[2] = address(token2); - // Deploy the borrower contract - borrower = new FlashBorrower(address(pool), tokenAddresses); + borrower = new FlashBorrower(address(pool)); // Mint tokens to alice to be used for repayments token0.mint(alice, INIT_BAL * 2); @@ -876,18 +843,17 @@ contract PartyPoolTest is Test { borrower.setAction(FlashBorrower.Action.NORMAL, alice); // Create loan request for token0 only - uint256[] memory amounts = new uint256[](3); - amounts[0] = 1000; // Only borrow token0 + uint256 amount = 1000; // Record balances before flash uint256 aliceToken0Before = token0.balanceOf(alice); uint256 poolToken0Before = token0.balanceOf(address(pool)); // Execute flash loan - borrower.flash(amounts); + borrower.flash(address(token0), amount); // Net change for alice should equal the flash fee (principal is returned during repayment) - uint256 fee = (amounts[0] * pool.flashFeePpm() + 1_000_000 - 1) / 1_000_000; // ceil fee calculation + uint256 fee = (amount * pool.flashFeePpm() + 1_000_000 - 1) / 1_000_000; // ceil fee calculation uint256 expectedAliceDecrease = fee; assertEq( aliceToken0Before - token0.balanceOf(alice), @@ -903,126 +869,6 @@ contract PartyPoolTest is Test { ); } - /// @notice Test flash loan with multiple tokens - function testFlashLoanMultipleTokens() public { - FlashBorrower borrower = setupFlashBorrower(); - - // Configure borrower to repay normally - borrower.setAction(FlashBorrower.Action.NORMAL, alice); - - // Create loan request for all tokens - uint256[] memory amounts = new uint256[](3); - amounts[0] = 1000; - amounts[1] = 2000; - amounts[2] = 3000; - - // Record balances before flash - uint256[] memory aliceBalancesBefore = new uint256[](3); - uint256[] memory poolBalancesBefore = new uint256[](3); - - aliceBalancesBefore[0] = token0.balanceOf(alice); - aliceBalancesBefore[1] = token1.balanceOf(alice); - aliceBalancesBefore[2] = token2.balanceOf(alice); - - poolBalancesBefore[0] = token0.balanceOf(address(pool)); - poolBalancesBefore[1] = token1.balanceOf(address(pool)); - poolBalancesBefore[2] = token2.balanceOf(address(pool)); - - // Execute flash loan - borrower.flash(amounts); - - // Check balances for each token - for (uint256 i = 0; i < 3; i++) { - uint256 fee = (amounts[i] * pool.flashFeePpm() + 1_000_000 - 1) / 1_000_000; // ceil fee calculation - uint256 expectedAliceDecrease = fee; - - IERC20 token; - if (i == 0) token = token0; - else if (i == 1) token = token1; - else token = token2; - - // Net change for Alice should equal the flash fee for this token (principal was returned) - assertEq( - aliceBalancesBefore[i] - token.balanceOf(alice), - expectedAliceDecrease, - "Alice should pay flash fee for token" - ); - - // Pool's balance increased by fee - assertEq( - token.balanceOf(address(pool)), - poolBalancesBefore[i] + fee, - "Pool should receive fee for token" - ); - } - } - - /// @notice Test flash loan with some zero amounts (should be skipped) - function testFlashLoanWithZeroAmounts() public { - FlashBorrower borrower = setupFlashBorrower(); - - // Configure borrower to repay normally - borrower.setAction(FlashBorrower.Action.NORMAL, alice); - - // Create loan request with mix of zero and non-zero amounts - uint256[] memory amounts = new uint256[](3); - amounts[0] = 0; // Zero - should be skipped - amounts[1] = 2000; // Non-zero - amounts[2] = 0; // Zero - should be skipped - - // Record balances before flash - uint256 aliceToken1Before = token1.balanceOf(alice); - uint256 poolToken1Before = token1.balanceOf(address(pool)); - - // Tokens that should remain unchanged - uint256 aliceToken0Before = token0.balanceOf(alice); - uint256 aliceToken2Before = token2.balanceOf(alice); - uint256 poolToken0Before = token0.balanceOf(address(pool)); - uint256 poolToken2Before = token2.balanceOf(address(pool)); - - // Execute flash loan - borrower.flash(amounts); - - // Check token1 balances changed appropriately - uint256 fee = (amounts[1] * pool.flashFeePpm() + 1_000_000 - 1) / 1_000_000; // ceil fee calculation - uint256 expectedAliceDecrease = fee; - - assertEq( - aliceToken1Before - token1.balanceOf(alice), - expectedAliceDecrease, - "Alice should pay flash fee for token1" - ); - - assertEq( - token1.balanceOf(address(pool)), - poolToken1Before + fee, - "Pool should receive fee for token1" - ); - - // Check token0 and token2 balances remained unchanged - assertEq(token0.balanceOf(alice), aliceToken0Before, "Alice token0 balance should be unchanged"); - assertEq(token2.balanceOf(alice), aliceToken2Before, "Alice token2 balance should be unchanged"); - assertEq(token0.balanceOf(address(pool)), poolToken0Before, "Pool token0 balance should be unchanged"); - assertEq(token2.balanceOf(address(pool)), poolToken2Before, "Pool token2 balance should be unchanged"); - } - - /// @notice Test that flash reverts when all amounts are zero - function testFlashLoanAllZeroAmountsReverts() public { - FlashBorrower borrower = setupFlashBorrower(); - - // Configure borrower to repay normally - borrower.setAction(FlashBorrower.Action.NORMAL, alice); - - // Create loan request with all zeros - uint256[] memory amounts = new uint256[](3); - amounts[0] = 0; - amounts[1] = 0; - amounts[2] = 0; - - // Execute flash loan - should revert - vm.expectRevert(bytes("flash: no tokens requested")); - borrower.flash(amounts); - } /// @notice Test flash loan with incorrect repayment (none) function testFlashLoanNoRepaymentReverts() public { @@ -1032,12 +878,11 @@ contract PartyPoolTest is Test { borrower.setAction(FlashBorrower.Action.REPAY_NONE, alice); // Create loan request - uint256[] memory amounts = new uint256[](3); - amounts[0] = 1000; + uint256 amount = 1000; - // Execute flash loan - should revert on validation - vm.expectRevert(bytes("flash: repayment failed")); - borrower.flash(amounts); + // Execute flash loan - should revert due to insufficient allowance when pool tries to pull repayment + vm.expectRevert(); + borrower.flash(address(token0), amount); } /// @notice Test flash loan with partial repayment (should revert) @@ -1048,12 +893,11 @@ contract PartyPoolTest is Test { borrower.setAction(FlashBorrower.Action.REPAY_PARTIAL, alice); // Create loan request - uint256[] memory amounts = new uint256[](3); - amounts[0] = 1000; + uint256 amount = 1000; - // Execute flash loan - should revert on validation - vm.expectRevert(bytes("flash: repayment failed")); - borrower.flash(amounts); + // Execute flash loan - should revert due to insufficient allowance when pool tries to pull full repayment + vm.expectRevert(); + borrower.flash(address(token0), amount); } /// @notice Test flash loan with principal repayment but no fee (should revert) @@ -1064,16 +908,15 @@ contract PartyPoolTest is Test { borrower.setAction(FlashBorrower.Action.REPAY_NO_FEE, alice); // Create loan request - uint256[] memory amounts = new uint256[](3); - amounts[0] = 1000; + uint256 amount = 1000; - // Execute flash loan - should revert on validation if fee > 0 + // Execute flash loan - should revert due to insufficient allowance if fee > 0 if (pool.flashFeePpm() > 0) { - vm.expectRevert(bytes("flash: repayment failed")); - borrower.flash(amounts); + vm.expectRevert(); + borrower.flash(address(token0), amount); } else { // If fee is zero, this should succeed - borrower.flash(amounts); + borrower.flash(address(token0), amount); } } @@ -1085,18 +928,17 @@ contract PartyPoolTest is Test { borrower.setAction(FlashBorrower.Action.REPAY_EXACT, alice); // Create loan request - uint256[] memory amounts = new uint256[](3); - amounts[0] = 1000; + uint256 amount = 1000; // Record balances before flash uint256 aliceToken0Before = token0.balanceOf(alice); uint256 poolToken0Before = token0.balanceOf(address(pool)); // Execute flash loan - borrower.flash(amounts); + borrower.flash(address(token0), amount); // Check balances: net change for alice should equal the fee - uint256 fee = (amounts[0] * pool.flashFeePpm() + 1_000_000 - 1) / 1_000_000; // ceil fee calculation + uint256 fee = (amount * pool.flashFeePpm() + 1_000_000 - 1) / 1_000_000; // ceil fee calculation uint256 expectedAliceDecrease = fee; assertEq( @@ -1112,115 +954,29 @@ contract PartyPoolTest is Test { ); } - /// @notice Test flash loan with extra repayment (donation, should succeed) - function testFlashLoanExtraRepayment() public { - FlashBorrower borrower = setupFlashBorrower(); + /// @notice Test flashFee view function matches flash implementation + function testFlashFee() public view { + // Test different loan amounts + uint256[] memory testAmounts = new uint256[](3); + testAmounts[0] = 1000; + testAmounts[1] = 2000; + testAmounts[2] = 3000; - // Configure borrower to repay more than required - borrower.setAction(FlashBorrower.Action.REPAY_EXTRA, alice); + for (uint256 i = 0; i < testAmounts.length; i++) { + uint256 amount = testAmounts[i]; + uint256 fee = viewer.flashFee(pool, address(token0), amount); - // Create loan request - uint256[] memory amounts = new uint256[](3); - amounts[0] = 1000; + // Calculate expected fee + uint256 expectedFee = (amount * pool.flashFeePpm() + 1_000_000 - 1) / 1_000_000; // ceiling - // Record balances before flash - uint256 aliceToken0Before = token0.balanceOf(alice); - uint256 poolToken0Before = token0.balanceOf(address(pool)); - - // Execute flash loan - borrower.flash(amounts); - - // Check balances - net change for alice should equal fee + extra donation (principal returned) - uint256 fee = (amounts[0] * pool.flashFeePpm() + 1_000_000 - 1) / 1_000_000; // ceil fee calculation - uint256 extra = 1; // borrower donates +1 per token in REPAY_EXTRA - uint256 expectedAliceDecrease = fee + extra; // fee plus donation - - assertEq( - aliceToken0Before - token0.balanceOf(alice), - expectedAliceDecrease, - "Alice should pay fee + extra" - ); - - assertEq( - token0.balanceOf(address(pool)), - poolToken0Before + fee + extra, - "Pool should receive fee + extra" - ); - } - - /// @notice Test flashRepaymentAmounts matches flash implementation - function testFlashRepaymentAmounts() public view { - // Create different loan amount scenarios - uint256[][] memory testCases = new uint256[][](3); - - // Case 1: Single token - testCases[0] = new uint256[](3); - testCases[0][0] = 1000; - testCases[0][1] = 0; - testCases[0][2] = 0; - - // Case 2: Multiple tokens - testCases[1] = new uint256[](3); - testCases[1][0] = 1000; - testCases[1][1] = 2000; - testCases[1][2] = 3000; - - // Case 3: Mix of zero and non-zero - testCases[2] = new uint256[](3); - testCases[2][0] = 0; - testCases[2][1] = 2000; - testCases[2][2] = 0; - - for (uint256 i = 0; i < testCases.length; i++) { - uint256[] memory loanAmounts = testCases[i]; - uint256[] memory repaymentAmounts = viewer.flashRepaymentAmounts(pool, loanAmounts); - - // Verify each repayment amount is correctly calculated - for (uint256 j = 0; j < loanAmounts.length; j++) { - if (loanAmounts[j] == 0) { - // Zero loans should have zero repayment - assertEq(repaymentAmounts[j], 0, "Zero loan should have zero repayment"); - } else { - // Calculate expected repayment with fee - uint256 fee = (loanAmounts[j] * pool.flashFeePpm() + 1_000_000 - 1) / 1_000_000; // ceiling - uint256 expectedRepayment = loanAmounts[j] + fee; - - assertEq( - repaymentAmounts[j], - expectedRepayment, - "Repayment calculation mismatch" - ); - } - } + assertEq( + fee, + expectedFee, + "Flash fee calculation mismatch" + ); } } - /// @notice Test flash with invalid recipient - function testFlashWithZeroRecipientReverts() public { - FlashBorrower borrower = setupFlashBorrower(); - - // Configure borrower with zero recipient - borrower.setAction(FlashBorrower.Action.NORMAL, address(0)); - - // Create loan request - uint256[] memory amounts = new uint256[](3); - amounts[0] = 1000; - - // Execute flash loan - should revert due to zero recipient - vm.expectRevert(bytes("flash: zero recipient")); - borrower.flash(amounts); - } - - /// @notice Test flash with incorrect amounts length - function testFlashWithIncorrectLengthReverts() public { - // Call flash directly with incorrect length - uint256[] memory wrongLengthAmounts = new uint256[](2); // Pool has 3 tokens - wrongLengthAmounts[0] = 1000; - wrongLengthAmounts[1] = 2000; - - vm.expectRevert(bytes("flash: amounts length mismatch")); - pool.flash(alice, wrongLengthAmounts, ""); - } /// @notice Test that passing nonzero lpTokens to initialMint doesn't affect swap results /// compared to pools initialized with default lpTokens (0)