flashLoan rewritten as ERC-3156

This commit is contained in:
tim
2025-10-07 14:12:27 -04:00
parent 677ce4886c
commit ef039aa57e
5 changed files with 199 additions and 505 deletions

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: UNLICENSED // SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.30; pragma solidity ^0.8.30;
import "../lib/openzeppelin-contracts/contracts/interfaces/IERC3156FlashBorrower.sol";
import "./LMSRStabilized.sol"; import "./LMSRStabilized.sol";
import {IERC20Metadata} from "../lib/openzeppelin-contracts/contracts/token/ERC20/extensions/IERC20Metadata.sol"; import {IERC20Metadata} from "../lib/openzeppelin-contracts/contracts/token/ERC20/extensions/IERC20Metadata.sol";
import {IERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/IERC20.sol"; import {IERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/IERC20.sol";
@@ -135,12 +136,14 @@ interface IPartyPool is IERC20Metadata {
/// @param maxAmountIn maximum gross input allowed (inclusive of fee) /// @param maxAmountIn maximum gross input allowed (inclusive of fee)
/// @param limitPrice maximum acceptable marginal price (pass 0 to ignore) /// @param limitPrice maximum acceptable marginal price (pass 0 to ignore)
/// @return amountIn gross input amount to transfer (includes fee), amountOut output amount user would receive, fee fee amount taken /// @return amountIn gross input amount to transfer (includes fee), amountOut output amount user would receive, fee fee amount taken
/*
function swapAmounts( function swapAmounts(
uint256 inputTokenIndex, uint256 inputTokenIndex,
uint256 outputTokenIndex, uint256 outputTokenIndex,
uint256 maxAmountIn, uint256 maxAmountIn,
int128 limitPrice int128 limitPrice
) external view returns (uint256 amountIn, uint256 amountOut, uint256 fee); ) external view returns (uint256 amountIn, uint256 amountOut, uint256 fee);
*/
/// @notice Swap input token inputTokenIndex -> token outputTokenIndex. Payer must approve token inputTokenIndex. /// @notice Swap input token inputTokenIndex -> token outputTokenIndex. Payer must approve token inputTokenIndex.
/// @dev This function transfers the exact gross input (including fee) from payer and sends the computed output to receiver. /// @dev This function transfers the exact gross input (including fee) from payer and sends the computed output to receiver.
@@ -215,17 +218,17 @@ interface IPartyPool is IERC20Metadata {
uint256 deadline uint256 deadline
) external returns (uint256 amountOutUint); ) external returns (uint256 amountOutUint);
/// @notice Receive token amounts and require them to be repaid plus a fee inside a callback. /**
/// @dev The caller must implement IPartyFlashCallback#partyFlashCallback which receives (amounts, repaymentAmounts, data). * @dev Initiate a flash loan.
/// This function verifies that, after the callback returns, the pool's balances have increased by at least the fees * @param receiver The receiver of the tokens in the loan, and the receiver of the callback.
/// for each borrowed token. Reverts if repayment (including fee) did not occur. * @param token The loan currency.
/// @param recipient The address which will receive the token amounts * @param amount The amount of tokens lent.
/// @param amounts The amount of each token to send (array length must equal pool size) * @param data Arbitrary data structure, intended to contain user-defined parameters.
/// @param data Any data to be passed through to the callback */
function flash( function flashLoan(
address recipient, IERC3156FlashBorrower receiver,
uint256[] memory amounts, address token,
uint256 amount,
bytes calldata data bytes calldata data
) external; ) external returns (bool);
} }

View File

@@ -1,20 +1,22 @@
// SPDX-License-Identifier: UNLICENSED // SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.30; pragma solidity ^0.8.30;
import "../lib/openzeppelin-contracts/contracts/interfaces/IERC3156FlashBorrower.sol";
import {ABDKMath64x64} from "../lib/abdk-libraries-solidity/ABDKMath64x64.sol"; import {ABDKMath64x64} from "../lib/abdk-libraries-solidity/ABDKMath64x64.sol";
import {IERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/IERC20.sol";
import {SafeERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/utils/SafeERC20.sol";
import {Address} from "../lib/openzeppelin-contracts/contracts/utils/Address.sol"; import {Address} from "../lib/openzeppelin-contracts/contracts/utils/Address.sol";
import {ReentrancyGuard} from "../lib/openzeppelin-contracts/contracts/utils/ReentrancyGuard.sol";
import {ERC20External} from "./ERC20External.sol"; import {ERC20External} from "./ERC20External.sol";
import {IERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/IERC20.sol";
import {IPartyFlashCallback} from "./IPartyFlashCallback.sol"; import {IPartyFlashCallback} from "./IPartyFlashCallback.sol";
import {IPartyPool} from "./IPartyPool.sol"; import {IPartyPool} from "./IPartyPool.sol";
import {LMSRStabilized} from "./LMSRStabilized.sol";
import {LMSRStabilizedBalancedPair} from "./LMSRStabilizedBalancedPair.sol"; import {LMSRStabilizedBalancedPair} from "./LMSRStabilizedBalancedPair.sol";
import {LMSRStabilized} from "./LMSRStabilized.sol";
import {PartyPoolBase} from "./PartyPoolBase.sol"; import {PartyPoolBase} from "./PartyPoolBase.sol";
import {PartyPoolMintImpl} from "./PartyPoolMintImpl.sol"; import {PartyPoolMintImpl} from "./PartyPoolMintImpl.sol";
import {PartyPoolSwapImpl} from "./PartyPoolSwapImpl.sol"; import {PartyPoolSwapImpl} from "./PartyPoolSwapImpl.sol";
import {Proxy} from "../lib/openzeppelin-contracts/contracts/proxy/Proxy.sol"; import {Proxy} from "../lib/openzeppelin-contracts/contracts/proxy/Proxy.sol";
import {ReentrancyGuard} from "../lib/openzeppelin-contracts/contracts/utils/ReentrancyGuard.sol";
import {SafeERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/utils/SafeERC20.sol";
import {IERC3156FlashLender} from "../lib/openzeppelin-contracts/contracts/interfaces/IERC3156FlashLender.sol";
/// @title PartyPool - LMSR-backed multi-asset pool with LP ERC20 token /// @title PartyPool - LMSR-backed multi-asset pool with LP ERC20 token
/// @notice A multi-asset liquidity pool backed by the LMSRStabilized pricing model. /// @notice A multi-asset liquidity pool backed by the LMSRStabilized pricing model.
@@ -196,7 +198,7 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool {
Swaps Swaps
---------------------- */ ---------------------- */
/// @inheritdoc IPartyPool /*
function swapAmounts( function swapAmounts(
uint256 inputTokenIndex, uint256 inputTokenIndex,
uint256 outputTokenIndex, uint256 outputTokenIndex,
@@ -206,6 +208,7 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool {
(uint256 grossIn, uint256 outUint,,,, uint256 feeUint) = _quoteSwapExactIn(inputTokenIndex, outputTokenIndex, maxAmountIn, limitPrice); (uint256 grossIn, uint256 outUint,,,, uint256 feeUint) = _quoteSwapExactIn(inputTokenIndex, outputTokenIndex, maxAmountIn, limitPrice);
return (grossIn, outUint, feeUint); return (grossIn, outUint, feeUint);
} }
*/
/// @inheritdoc IPartyPool /// @inheritdoc IPartyPool
function swap( function swap(
@@ -295,7 +298,6 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool {
require(deltaInternalI > int128(0), "swap: input too small after fee"); require(deltaInternalI > int128(0), "swap: input too small after fee");
// Compute internal amounts using LMSR (exact-input with price limit) // Compute internal amounts using LMSR (exact-input with price limit)
// if _stablePair is true, use the optimized path
(amountInInternalUsed, amountOutInternal) = _swapAmountsForExactInput(inputTokenIndex, outputTokenIndex, deltaInternalI, limitPrice); (amountInInternalUsed, amountOutInternal) = _swapAmountsForExactInput(inputTokenIndex, outputTokenIndex, deltaInternalI, limitPrice);
// Convert actual used input internal -> uint (ceil) // Convert actual used input internal -> uint (ceil)
@@ -405,82 +407,38 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool {
} }
/// @notice Receive token amounts and require them to be repaid plus a fee inside a callback. bytes32 internal constant FLASH_CALLBACK_SUCCESS = keccak256("ERC3156FlashBorrower.onFlashLoan");
/// @dev The caller must implement IPartyFlashCallback#partyFlashCallback which receives (amounts, repaymentAmounts, data).
/// This function verifies that, after the callback returns, the pool's balances have increased by at least the fees /**
/// for each borrowed token. Reverts if repayment (including fee) did not occur. * @dev Loan `amount` tokens to `receiver`, and takes it back plus a `flashFee` after the callback.
/// @param recipient The address which will receive the token amounts * @param receiver The contract receiving the tokens, needs to implement the `onFlashLoan(address user, uint256 amount, uint256 fee, bytes calldata)` interface.
/// @param amounts The amount of each token to send (array length must equal pool size) * @param tokenAddr The loan currency.
/// @param data Any data to be passed through to the callback * @param amount The amount of tokens lent.
// todo gas-efficient single-asset flash * @param data A data parameter to be passed on to the `receiver` for any custom use.
// todo fix this func's gas */
function flash( function flashLoan(
address recipient, IERC3156FlashBorrower receiver,
uint256[] memory amounts, address tokenAddr,
uint256 amount,
bytes calldata data bytes calldata data
) external nonReentrant { ) external nonReentrant returns (bool)
require(recipient != address(0), "flash: zero recipient"); {
require(amounts.length == tokens.length, "flash: amounts length mismatch"); IERC20 token = IERC20(tokenAddr);
require(amount <= token.balanceOf(address(this)));
// Calculate repayment amounts for each token including fee (uint256 fee, ) = _computeFee(amount, FLASH_FEE_PPM);
uint256[] memory repaymentAmounts = new uint256[](tokens.length); require(
token.transfer(address(receiver), amount),
// Store initial balances to verify repayment later "FlashLender: Transfer failed"
uint256[] memory initialBalances = new uint256[](tokens.length); );
require(
// Track if any token amount is non-zero receiver.onFlashLoan(msg.sender, address(token), amount, fee, data) == FLASH_CALLBACK_SUCCESS,
bool hasNonZeroAmount = false; "FlashLender: Callback failed"
);
// Process each token, skipping those with zero amounts require(
for (uint256 i = 0; i < tokens.length; i++) { token.transferFrom(address(receiver), address(this), amount + fee),
uint256 amount = amounts[i]; "FlashLender: Repay failed"
);
if (amount > 0) { return true;
hasNonZeroAmount = true;
// Calculate repayment amount with fee (ceiling)
repaymentAmounts[i] = amount + _ceilFee(amount, FLASH_FEE_PPM);
// Record initial balance
initialBalances[i] = IERC20(tokens[i]).balanceOf(address(this));
// Transfer token to recipient
tokens[i].safeTransfer(recipient, amount);
}
}
// Ensure at least one token is being borrowed
require(hasNonZeroAmount, "flash: no tokens requested");
// Call flash callback with expected repayment amounts
IPartyFlashCallback(msg.sender).partyFlashCallback(amounts, repaymentAmounts, data);
// Verify repayment amounts for tokens that were borrowed
for (uint256 i = 0; i < tokens.length; i++) {
if (amounts[i] > 0) {
uint256 currentBalance = IERC20(tokens[i]).balanceOf(address(this));
// Compute expected fee (ceiling)
uint256 feeExpected = _ceilFee(amounts[i], FLASH_FEE_PPM);
// Verify repayment: current balance must be at least (initial balance + fee)
require(
currentBalance >= initialBalances[i] + feeExpected,
"flash: repayment failed"
);
// Accrue protocol share (floor) of the flash fee
if (PROTOCOL_FEE_PPM > 0 && PROTOCOL_FEE_ADDRESS != address(0)) {
uint256 protoShare = (feeExpected * PROTOCOL_FEE_PPM) / 1_000_000; // floor
if (protoShare > 0) {
protocolFeesOwed[i] += protoShare;
}
}
// Update cached balance to onchain minus owed
_recordCachedBalance(i, currentBalance);
}
}
} }

View File

@@ -2,6 +2,7 @@
pragma solidity ^0.8.30; pragma solidity ^0.8.30;
import {ABDKMath64x64} from "../lib/abdk-libraries-solidity/ABDKMath64x64.sol"; import {ABDKMath64x64} from "../lib/abdk-libraries-solidity/ABDKMath64x64.sol";
import {IERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/IERC20.sol";
import {IPartyPool} from "./IPartyPool.sol"; import {IPartyPool} from "./IPartyPool.sol";
import {LMSRStabilized} from "./LMSRStabilized.sol"; import {LMSRStabilized} from "./LMSRStabilized.sol";
import {PartyPoolHelpers} from "./PartyPoolHelpers.sol"; import {PartyPoolHelpers} from "./PartyPoolHelpers.sol";
@@ -151,4 +152,31 @@ contract PartyPoolView is PartyPoolHelpers {
} }
/**
* @dev The amount of currency available to be lent.
* @param token The loan currency.
* @return The amount of `token` that can be borrowed.
*/
function maxFlashLoan(
IPartyPool pool,
address token
) external view returns (uint256) {
return IERC20(token).balanceOf(address(pool));
}
/**
* @dev The fee to be charged for a given loan.
* @param token The loan currency.
* @param amount The amount of tokens lent.
* @return fee The amount of `token` to be charged for the loan, on top of the returned principal.
*/
function flashFee(
IPartyPool pool,
address token,
uint256 amount
) external view returns (uint256 fee) {
(fee,) = _computeFee(amount, pool.flashFeePpm());
}
} }

View File

@@ -8,104 +8,74 @@ import "@openzeppelin/contracts/token/ERC20/ERC20.sol";
import "../src/LMSRStabilized.sol"; import "../src/LMSRStabilized.sol";
import "../src/PartyPool.sol"; import "../src/PartyPool.sol";
import "../src/PartyPlanner.sol"; import "../src/PartyPlanner.sol";
import "../src/IPartyFlashCallback.sol"; import "../lib/openzeppelin-contracts/contracts/interfaces/IERC3156FlashBorrower.sol";
import {Deploy} from "../src/Deploy.sol"; import {Deploy} from "../src/Deploy.sol";
/// @notice Test contract that implements the flash callback for testing flash loans /// @notice Test contract that implements the flash callback for testing flash loans
contract FlashBorrower is IPartyFlashCallback { contract FlashBorrower is IERC3156FlashBorrower {
enum Action { enum Action {
NORMAL, // Normal repayment NORMAL, // Normal repayment
REPAY_NONE, // Don't repay anything REPAY_NONE, // Don't repay anything
REPAY_PARTIAL, // Repay less than required REPAY_PARTIAL, // Repay less than required
REPAY_NO_FEE, // Repay only the principal without fee REPAY_NO_FEE, // Repay only the principal without fee
REPAY_EXACT, // Repay exactly the required amount REPAY_EXACT // Repay exactly the required amount
REPAY_EXTRA // Repay more than required (donation)
} }
Action public action; Action public action;
address public pool; address public pool;
address public recipient; address public payer;
address[] public tokens;
constructor(address _pool, IERC20[] memory _tokens) { constructor(address _pool) {
pool = _pool; pool = _pool;
tokens = new address[](_tokens.length);
for (uint i = 0; i < _tokens.length; i++) {
tokens[i] = address(_tokens[i]);
}
} }
function setAction(Action _action, address _recipient) external { function setAction(Action _action, address _payer) external {
action = _action; action = _action;
recipient = _recipient; payer = _payer;
} }
function flash(uint256[] memory amounts) external { function flash(address token, uint256 amount) external {
PartyPool(pool).flash(recipient, amounts, ""); PartyPool(pool).flashLoan(IERC3156FlashBorrower(address(this)), token, amount, "");
} }
function partyFlashCallback( function onFlashLoan(
uint256[] memory loanAmounts, address initiator,
uint256[] memory repaymentAmounts, address token,
uint256 amount,
uint256 fee,
bytes calldata /* data */ bytes calldata /* data */
) external override { ) external override returns (bytes32) {
require(msg.sender == pool, "Callback not called by pool"); require(msg.sender == pool, "Callback not called by pool");
if (action == Action.NORMAL || action == Action.REPAY_EXTRA) { if (action == Action.NORMAL) {
// Normal or extra repayment - transfer required amounts back to pool // Normal repayment
for (uint256 i = 0; i < loanAmounts.length; i++) { // We received 'amount' from the pool, need to pay back amount + fee
if (loanAmounts[i] > 0) { uint256 repaymentAmount = amount + fee;
uint256 repaymentAmount = repaymentAmounts[i];
// For REPAY_EXTRA, add 1 to each repayment // Transfer the fee from payer to this contract
if (action == Action.REPAY_EXTRA) { // (we already have the principal 'amount' from the flash loan)
repaymentAmount += 1; TestERC20(token).transferFrom(payer, address(this), fee);
}
// Transfer from recipient back to pool // Approve pool to pull back the full repayment
TestERC20(tokens[i]).transferFrom( TestERC20(token).approve(pool, repaymentAmount);
recipient,
pool,
repaymentAmount
);
}
}
} else if (action == Action.REPAY_PARTIAL) { } else if (action == Action.REPAY_PARTIAL) {
// Repay half of the required amounts // Repay half of the required amount
for (uint256 i = 0; i < loanAmounts.length; i++) { uint256 partialRepayment = (amount + fee) / 2;
if (loanAmounts[i] > 0) { TestERC20(token).approve(pool, partialRepayment);
uint256 partialRepayment = repaymentAmounts[i] / 2;
TestERC20(tokens[i]).transferFrom(
recipient,
pool,
partialRepayment
);
}
}
} else if (action == Action.REPAY_NO_FEE) { } else if (action == Action.REPAY_NO_FEE) {
// Repay only the principal without fee // Repay only the principal without fee (we already have it from the loan)
for (uint256 i = 0; i < loanAmounts.length; i++) { TestERC20(token).approve(pool, amount);
if (loanAmounts[i] > 0) {
TestERC20(tokens[i]).transferFrom(
recipient,
pool,
loanAmounts[i]
);
}
}
} else if (action == Action.REPAY_EXACT) { } else if (action == Action.REPAY_EXACT) {
// Repay exactly what was required // Repay exactly what was required
for (uint256 i = 0; i < loanAmounts.length; i++) { uint256 repaymentAmount = amount + fee;
if (loanAmounts[i] > 0) { // Transfer the fee from payer (we have the principal from the loan)
TestERC20(tokens[i]).transferFrom( TestERC20(token).transferFrom(payer, address(this), fee);
recipient, // Approve pool to pull back the full repayment
pool, TestERC20(token).approve(pool, repaymentAmount);
repaymentAmounts[i]
);
}
}
} }
// For REPAY_NONE, do nothing (don't repay) // For REPAY_NONE, do nothing (don't approve repayment)
return keccak256("ERC3156FlashBorrower.onFlashLoan");
} }
} }
@@ -247,13 +217,11 @@ contract GasTest is Test {
/// @notice Setup a flash borrower for testing /// @notice Setup a flash borrower for testing
function setupFlashBorrower() internal returns (FlashBorrower borrower) { function setupFlashBorrower() internal returns (FlashBorrower borrower) {
// Get token addresses from the 2-token pool
IERC20[] memory tokenAddresses = pool2.allTokens();
// Deploy the borrower contract // Deploy the borrower contract
borrower = new FlashBorrower(address(pool2), tokenAddresses); borrower = new FlashBorrower(address(pool2));
// Mint tokens to alice to be used for repayments and approve borrower // Mint tokens to alice to be used for repayments and approve borrower
IERC20[] memory tokenAddresses = pool2.allTokens();
vm.startPrank(alice); vm.startPrank(alice);
for (uint256 i = 0; i < tokenAddresses.length; i++) { for (uint256 i = 0; i < tokenAddresses.length; i++) {
TestERC20(address(tokenAddresses[i])).mint(alice, INIT_BAL * 2); TestERC20(address(tokenAddresses[i])).mint(alice, INIT_BAL * 2);
@@ -441,33 +409,14 @@ contract GasTest is Test {
// Configure borrower // Configure borrower
borrower.setAction(FlashBorrower.Action.NORMAL, alice); borrower.setAction(FlashBorrower.Action.NORMAL, alice);
// Create loan request for single token (get array size from pool) // Get first token from pool
IERC20[] memory poolTokens = pool2.allTokens(); IERC20[] memory poolTokens = pool2.allTokens();
uint256[] memory amounts = new uint256[](poolTokens.length); address token = address(poolTokens[0]);
amounts[0] = 1000; uint256 amount = 1000;
// Execute flash loan 10 times to measure gas // Execute flash loan 10 times to measure gas
for (uint256 i = 0; i < 10; i++) { for (uint256 i = 0; i < 10; i++) {
borrower.flash(amounts); borrower.flash(token, amount);
}
}
/// @notice Gas measurement: flash with multiple tokens
function testFlashGasMultipleTokens() public {
FlashBorrower borrower = setupFlashBorrower();
// Configure borrower
borrower.setAction(FlashBorrower.Action.NORMAL, alice);
// Create loan request for multiple tokens (get array size from pool)
IERC20[] memory poolTokens = pool2.allTokens();
uint256[] memory amounts = new uint256[](poolTokens.length);
amounts[0] = 1000;
amounts[1] = 2000;
// Execute flash loan 10 times to measure gas
for (uint256 i = 0; i < 10; i++) {
borrower.flash(amounts);
} }
} }
} }

View File

@@ -9,103 +9,76 @@ import "../src/LMSRStabilized.sol";
import "../src/PartyPool.sol"; import "../src/PartyPool.sol";
// Import the flash callback interface // Import the flash callback interface
import "../src/IPartyFlashCallback.sol"; import "../lib/openzeppelin-contracts/contracts/interfaces/IERC3156FlashBorrower.sol";
import {PartyPlanner} from "../src/PartyPlanner.sol"; import {PartyPlanner} from "../src/PartyPlanner.sol";
import {Deploy} from "../src/Deploy.sol"; import {Deploy} from "../src/Deploy.sol";
import {PartyPoolView} from "../src/PartyPoolView.sol"; import {PartyPoolView} from "../src/PartyPoolView.sol";
/// @notice Test contract that implements the flash callback for testing flash loans /// @notice Test contract that implements the flash callback for testing flash loans
contract FlashBorrower is IPartyFlashCallback { contract FlashBorrower is IERC3156FlashBorrower {
enum Action { enum Action {
NORMAL, // Normal repayment NORMAL, // Normal repayment
REPAY_NONE, // Don't repay anything REPAY_NONE, // Don't repay anything
REPAY_PARTIAL, // Repay less than required REPAY_PARTIAL, // Repay less than required
REPAY_NO_FEE, // Repay only the principal without fee REPAY_NO_FEE, // Repay only the principal without fee
REPAY_EXACT, // Repay exactly the required amount REPAY_EXACT // Repay exactly the required amount
REPAY_EXTRA // Repay more than required (donation)
} }
Action public action; Action public action;
address public pool; address public pool;
address public recipient; address public payer;
address[] public tokens;
constructor(address _pool, address[] memory _tokens) { constructor(address _pool) {
pool = _pool; pool = _pool;
tokens = _tokens;
} }
function setAction(Action _action, address _recipient) external { function setAction(Action _action, address _payer) external {
action = _action; action = _action;
recipient = _recipient; payer = _payer;
} }
function flash(uint256[] memory amounts) external { function flash(address token, uint256 amount) external {
PartyPool(pool).flash(recipient, amounts, ""); PartyPool(pool).flashLoan(IERC3156FlashBorrower(address(this)), token, amount, "");
} }
function partyFlashCallback( function onFlashLoan(
uint256[] memory loanAmounts, address initiator,
uint256[] memory repaymentAmounts, address token,
uint256 amount,
uint256 fee,
bytes calldata /* data */ bytes calldata /* data */
) external override { ) external override returns (bytes32) {
require(msg.sender == pool, "Callback not called by pool"); require(msg.sender == pool, "Callback not called by pool");
if (action == Action.NORMAL || action == Action.REPAY_EXTRA) { if (action == Action.NORMAL) {
// Normal or extra repayment - transfer required amounts back to pool // Normal repayment
for (uint256 i = 0; i < loanAmounts.length; i++) { // We received 'amount' from the pool, need to pay back amount + fee
if (loanAmounts[i] > 0) { uint256 repaymentAmount = amount + fee;
uint256 repaymentAmount = repaymentAmounts[i];
// For REPAY_EXTRA, add 1 to each repayment // Transfer the fee from payer to this contract
if (action == Action.REPAY_EXTRA) { // (we already have the principal 'amount' from the flash loan)
repaymentAmount += 1; TestERC20(token).transferFrom(payer, address(this), fee);
}
// Transfer from recipient back to pool // Approve pool to pull back the full repayment
TestERC20(tokens[i]).transferFrom( TestERC20(token).approve(pool, repaymentAmount);
recipient,
pool,
repaymentAmount
);
}
}
} else if (action == Action.REPAY_PARTIAL) { } else if (action == Action.REPAY_PARTIAL) {
// Repay half of the required amounts // Repay half of the required amount
for (uint256 i = 0; i < loanAmounts.length; i++) { uint256 partialRepayment = (amount + fee) / 2;
if (loanAmounts[i] > 0) { TestERC20(token).approve(pool, partialRepayment);
uint256 partialRepayment = repaymentAmounts[i] / 2;
TestERC20(tokens[i]).transferFrom(
recipient,
pool,
partialRepayment
);
}
}
} else if (action == Action.REPAY_NO_FEE) { } else if (action == Action.REPAY_NO_FEE) {
// Repay only the principal without fee // Repay only the principal without fee (we already have it from the loan)
for (uint256 i = 0; i < loanAmounts.length; i++) { TestERC20(token).approve(pool, amount);
if (loanAmounts[i] > 0) {
TestERC20(tokens[i]).transferFrom(
recipient,
pool,
loanAmounts[i]
);
}
}
} else if (action == Action.REPAY_EXACT) { } else if (action == Action.REPAY_EXACT) {
// Repay exactly what was required // Repay exactly what was required
for (uint256 i = 0; i < loanAmounts.length; i++) { uint256 repaymentAmount = amount + fee;
if (loanAmounts[i] > 0) { // Transfer the fee from payer (we have the principal from the loan)
TestERC20(tokens[i]).transferFrom( TestERC20(token).transferFrom(payer, address(this), fee);
recipient, // Approve pool to pull back the full repayment
pool, TestERC20(token).approve(pool, repaymentAmount);
repaymentAmounts[i]
);
}
}
} }
// For REPAY_NONE, do nothing (don't repay) // For REPAY_NONE, do nothing (don't approve repayment)
return keccak256("ERC3156FlashBorrower.onFlashLoan");
} }
} }
@@ -846,14 +819,8 @@ contract PartyPoolTest is Test {
/// @notice Setup a flash borrower for testing /// @notice Setup a flash borrower for testing
function setupFlashBorrower() internal returns (FlashBorrower borrower) { function setupFlashBorrower() internal returns (FlashBorrower borrower) {
// Create array of token addresses for borrower
address[] memory tokenAddresses = new address[](3);
tokenAddresses[0] = address(token0);
tokenAddresses[1] = address(token1);
tokenAddresses[2] = address(token2);
// Deploy the borrower contract // Deploy the borrower contract
borrower = new FlashBorrower(address(pool), tokenAddresses); borrower = new FlashBorrower(address(pool));
// Mint tokens to alice to be used for repayments // Mint tokens to alice to be used for repayments
token0.mint(alice, INIT_BAL * 2); token0.mint(alice, INIT_BAL * 2);
@@ -876,18 +843,17 @@ contract PartyPoolTest is Test {
borrower.setAction(FlashBorrower.Action.NORMAL, alice); borrower.setAction(FlashBorrower.Action.NORMAL, alice);
// Create loan request for token0 only // Create loan request for token0 only
uint256[] memory amounts = new uint256[](3); uint256 amount = 1000;
amounts[0] = 1000; // Only borrow token0
// Record balances before flash // Record balances before flash
uint256 aliceToken0Before = token0.balanceOf(alice); uint256 aliceToken0Before = token0.balanceOf(alice);
uint256 poolToken0Before = token0.balanceOf(address(pool)); uint256 poolToken0Before = token0.balanceOf(address(pool));
// Execute flash loan // Execute flash loan
borrower.flash(amounts); borrower.flash(address(token0), amount);
// Net change for alice should equal the flash fee (principal is returned during repayment) // Net change for alice should equal the flash fee (principal is returned during repayment)
uint256 fee = (amounts[0] * pool.flashFeePpm() + 1_000_000 - 1) / 1_000_000; // ceil fee calculation uint256 fee = (amount * pool.flashFeePpm() + 1_000_000 - 1) / 1_000_000; // ceil fee calculation
uint256 expectedAliceDecrease = fee; uint256 expectedAliceDecrease = fee;
assertEq( assertEq(
aliceToken0Before - token0.balanceOf(alice), aliceToken0Before - token0.balanceOf(alice),
@@ -903,126 +869,6 @@ contract PartyPoolTest is Test {
); );
} }
/// @notice Test flash loan with multiple tokens
function testFlashLoanMultipleTokens() public {
FlashBorrower borrower = setupFlashBorrower();
// Configure borrower to repay normally
borrower.setAction(FlashBorrower.Action.NORMAL, alice);
// Create loan request for all tokens
uint256[] memory amounts = new uint256[](3);
amounts[0] = 1000;
amounts[1] = 2000;
amounts[2] = 3000;
// Record balances before flash
uint256[] memory aliceBalancesBefore = new uint256[](3);
uint256[] memory poolBalancesBefore = new uint256[](3);
aliceBalancesBefore[0] = token0.balanceOf(alice);
aliceBalancesBefore[1] = token1.balanceOf(alice);
aliceBalancesBefore[2] = token2.balanceOf(alice);
poolBalancesBefore[0] = token0.balanceOf(address(pool));
poolBalancesBefore[1] = token1.balanceOf(address(pool));
poolBalancesBefore[2] = token2.balanceOf(address(pool));
// Execute flash loan
borrower.flash(amounts);
// Check balances for each token
for (uint256 i = 0; i < 3; i++) {
uint256 fee = (amounts[i] * pool.flashFeePpm() + 1_000_000 - 1) / 1_000_000; // ceil fee calculation
uint256 expectedAliceDecrease = fee;
IERC20 token;
if (i == 0) token = token0;
else if (i == 1) token = token1;
else token = token2;
// Net change for Alice should equal the flash fee for this token (principal was returned)
assertEq(
aliceBalancesBefore[i] - token.balanceOf(alice),
expectedAliceDecrease,
"Alice should pay flash fee for token"
);
// Pool's balance increased by fee
assertEq(
token.balanceOf(address(pool)),
poolBalancesBefore[i] + fee,
"Pool should receive fee for token"
);
}
}
/// @notice Test flash loan with some zero amounts (should be skipped)
function testFlashLoanWithZeroAmounts() public {
FlashBorrower borrower = setupFlashBorrower();
// Configure borrower to repay normally
borrower.setAction(FlashBorrower.Action.NORMAL, alice);
// Create loan request with mix of zero and non-zero amounts
uint256[] memory amounts = new uint256[](3);
amounts[0] = 0; // Zero - should be skipped
amounts[1] = 2000; // Non-zero
amounts[2] = 0; // Zero - should be skipped
// Record balances before flash
uint256 aliceToken1Before = token1.balanceOf(alice);
uint256 poolToken1Before = token1.balanceOf(address(pool));
// Tokens that should remain unchanged
uint256 aliceToken0Before = token0.balanceOf(alice);
uint256 aliceToken2Before = token2.balanceOf(alice);
uint256 poolToken0Before = token0.balanceOf(address(pool));
uint256 poolToken2Before = token2.balanceOf(address(pool));
// Execute flash loan
borrower.flash(amounts);
// Check token1 balances changed appropriately
uint256 fee = (amounts[1] * pool.flashFeePpm() + 1_000_000 - 1) / 1_000_000; // ceil fee calculation
uint256 expectedAliceDecrease = fee;
assertEq(
aliceToken1Before - token1.balanceOf(alice),
expectedAliceDecrease,
"Alice should pay flash fee for token1"
);
assertEq(
token1.balanceOf(address(pool)),
poolToken1Before + fee,
"Pool should receive fee for token1"
);
// Check token0 and token2 balances remained unchanged
assertEq(token0.balanceOf(alice), aliceToken0Before, "Alice token0 balance should be unchanged");
assertEq(token2.balanceOf(alice), aliceToken2Before, "Alice token2 balance should be unchanged");
assertEq(token0.balanceOf(address(pool)), poolToken0Before, "Pool token0 balance should be unchanged");
assertEq(token2.balanceOf(address(pool)), poolToken2Before, "Pool token2 balance should be unchanged");
}
/// @notice Test that flash reverts when all amounts are zero
function testFlashLoanAllZeroAmountsReverts() public {
FlashBorrower borrower = setupFlashBorrower();
// Configure borrower to repay normally
borrower.setAction(FlashBorrower.Action.NORMAL, alice);
// Create loan request with all zeros
uint256[] memory amounts = new uint256[](3);
amounts[0] = 0;
amounts[1] = 0;
amounts[2] = 0;
// Execute flash loan - should revert
vm.expectRevert(bytes("flash: no tokens requested"));
borrower.flash(amounts);
}
/// @notice Test flash loan with incorrect repayment (none) /// @notice Test flash loan with incorrect repayment (none)
function testFlashLoanNoRepaymentReverts() public { function testFlashLoanNoRepaymentReverts() public {
@@ -1032,12 +878,11 @@ contract PartyPoolTest is Test {
borrower.setAction(FlashBorrower.Action.REPAY_NONE, alice); borrower.setAction(FlashBorrower.Action.REPAY_NONE, alice);
// Create loan request // Create loan request
uint256[] memory amounts = new uint256[](3); uint256 amount = 1000;
amounts[0] = 1000;
// Execute flash loan - should revert on validation // Execute flash loan - should revert due to insufficient allowance when pool tries to pull repayment
vm.expectRevert(bytes("flash: repayment failed")); vm.expectRevert();
borrower.flash(amounts); borrower.flash(address(token0), amount);
} }
/// @notice Test flash loan with partial repayment (should revert) /// @notice Test flash loan with partial repayment (should revert)
@@ -1048,12 +893,11 @@ contract PartyPoolTest is Test {
borrower.setAction(FlashBorrower.Action.REPAY_PARTIAL, alice); borrower.setAction(FlashBorrower.Action.REPAY_PARTIAL, alice);
// Create loan request // Create loan request
uint256[] memory amounts = new uint256[](3); uint256 amount = 1000;
amounts[0] = 1000;
// Execute flash loan - should revert on validation // Execute flash loan - should revert due to insufficient allowance when pool tries to pull full repayment
vm.expectRevert(bytes("flash: repayment failed")); vm.expectRevert();
borrower.flash(amounts); borrower.flash(address(token0), amount);
} }
/// @notice Test flash loan with principal repayment but no fee (should revert) /// @notice Test flash loan with principal repayment but no fee (should revert)
@@ -1064,16 +908,15 @@ contract PartyPoolTest is Test {
borrower.setAction(FlashBorrower.Action.REPAY_NO_FEE, alice); borrower.setAction(FlashBorrower.Action.REPAY_NO_FEE, alice);
// Create loan request // Create loan request
uint256[] memory amounts = new uint256[](3); uint256 amount = 1000;
amounts[0] = 1000;
// Execute flash loan - should revert on validation if fee > 0 // Execute flash loan - should revert due to insufficient allowance if fee > 0
if (pool.flashFeePpm() > 0) { if (pool.flashFeePpm() > 0) {
vm.expectRevert(bytes("flash: repayment failed")); vm.expectRevert();
borrower.flash(amounts); borrower.flash(address(token0), amount);
} else { } else {
// If fee is zero, this should succeed // If fee is zero, this should succeed
borrower.flash(amounts); borrower.flash(address(token0), amount);
} }
} }
@@ -1085,18 +928,17 @@ contract PartyPoolTest is Test {
borrower.setAction(FlashBorrower.Action.REPAY_EXACT, alice); borrower.setAction(FlashBorrower.Action.REPAY_EXACT, alice);
// Create loan request // Create loan request
uint256[] memory amounts = new uint256[](3); uint256 amount = 1000;
amounts[0] = 1000;
// Record balances before flash // Record balances before flash
uint256 aliceToken0Before = token0.balanceOf(alice); uint256 aliceToken0Before = token0.balanceOf(alice);
uint256 poolToken0Before = token0.balanceOf(address(pool)); uint256 poolToken0Before = token0.balanceOf(address(pool));
// Execute flash loan // Execute flash loan
borrower.flash(amounts); borrower.flash(address(token0), amount);
// Check balances: net change for alice should equal the fee // Check balances: net change for alice should equal the fee
uint256 fee = (amounts[0] * pool.flashFeePpm() + 1_000_000 - 1) / 1_000_000; // ceil fee calculation uint256 fee = (amount * pool.flashFeePpm() + 1_000_000 - 1) / 1_000_000; // ceil fee calculation
uint256 expectedAliceDecrease = fee; uint256 expectedAliceDecrease = fee;
assertEq( assertEq(
@@ -1112,115 +954,29 @@ contract PartyPoolTest is Test {
); );
} }
/// @notice Test flash loan with extra repayment (donation, should succeed) /// @notice Test flashFee view function matches flash implementation
function testFlashLoanExtraRepayment() public { function testFlashFee() public view {
FlashBorrower borrower = setupFlashBorrower(); // Test different loan amounts
uint256[] memory testAmounts = new uint256[](3);
testAmounts[0] = 1000;
testAmounts[1] = 2000;
testAmounts[2] = 3000;
// Configure borrower to repay more than required for (uint256 i = 0; i < testAmounts.length; i++) {
borrower.setAction(FlashBorrower.Action.REPAY_EXTRA, alice); uint256 amount = testAmounts[i];
uint256 fee = viewer.flashFee(pool, address(token0), amount);
// Create loan request // Calculate expected fee
uint256[] memory amounts = new uint256[](3); uint256 expectedFee = (amount * pool.flashFeePpm() + 1_000_000 - 1) / 1_000_000; // ceiling
amounts[0] = 1000;
// Record balances before flash assertEq(
uint256 aliceToken0Before = token0.balanceOf(alice); fee,
uint256 poolToken0Before = token0.balanceOf(address(pool)); expectedFee,
"Flash fee calculation mismatch"
// Execute flash loan );
borrower.flash(amounts);
// Check balances - net change for alice should equal fee + extra donation (principal returned)
uint256 fee = (amounts[0] * pool.flashFeePpm() + 1_000_000 - 1) / 1_000_000; // ceil fee calculation
uint256 extra = 1; // borrower donates +1 per token in REPAY_EXTRA
uint256 expectedAliceDecrease = fee + extra; // fee plus donation
assertEq(
aliceToken0Before - token0.balanceOf(alice),
expectedAliceDecrease,
"Alice should pay fee + extra"
);
assertEq(
token0.balanceOf(address(pool)),
poolToken0Before + fee + extra,
"Pool should receive fee + extra"
);
}
/// @notice Test flashRepaymentAmounts matches flash implementation
function testFlashRepaymentAmounts() public view {
// Create different loan amount scenarios
uint256[][] memory testCases = new uint256[][](3);
// Case 1: Single token
testCases[0] = new uint256[](3);
testCases[0][0] = 1000;
testCases[0][1] = 0;
testCases[0][2] = 0;
// Case 2: Multiple tokens
testCases[1] = new uint256[](3);
testCases[1][0] = 1000;
testCases[1][1] = 2000;
testCases[1][2] = 3000;
// Case 3: Mix of zero and non-zero
testCases[2] = new uint256[](3);
testCases[2][0] = 0;
testCases[2][1] = 2000;
testCases[2][2] = 0;
for (uint256 i = 0; i < testCases.length; i++) {
uint256[] memory loanAmounts = testCases[i];
uint256[] memory repaymentAmounts = viewer.flashRepaymentAmounts(pool, loanAmounts);
// Verify each repayment amount is correctly calculated
for (uint256 j = 0; j < loanAmounts.length; j++) {
if (loanAmounts[j] == 0) {
// Zero loans should have zero repayment
assertEq(repaymentAmounts[j], 0, "Zero loan should have zero repayment");
} else {
// Calculate expected repayment with fee
uint256 fee = (loanAmounts[j] * pool.flashFeePpm() + 1_000_000 - 1) / 1_000_000; // ceiling
uint256 expectedRepayment = loanAmounts[j] + fee;
assertEq(
repaymentAmounts[j],
expectedRepayment,
"Repayment calculation mismatch"
);
}
}
} }
} }
/// @notice Test flash with invalid recipient
function testFlashWithZeroRecipientReverts() public {
FlashBorrower borrower = setupFlashBorrower();
// Configure borrower with zero recipient
borrower.setAction(FlashBorrower.Action.NORMAL, address(0));
// Create loan request
uint256[] memory amounts = new uint256[](3);
amounts[0] = 1000;
// Execute flash loan - should revert due to zero recipient
vm.expectRevert(bytes("flash: zero recipient"));
borrower.flash(amounts);
}
/// @notice Test flash with incorrect amounts length
function testFlashWithIncorrectLengthReverts() public {
// Call flash directly with incorrect length
uint256[] memory wrongLengthAmounts = new uint256[](2); // Pool has 3 tokens
wrongLengthAmounts[0] = 1000;
wrongLengthAmounts[1] = 2000;
vm.expectRevert(bytes("flash: amounts length mismatch"));
pool.flash(alice, wrongLengthAmounts, "");
}
/// @notice Test that passing nonzero lpTokens to initialMint doesn't affect swap results /// @notice Test that passing nonzero lpTokens to initialMint doesn't affect swap results
/// compared to pools initialized with default lpTokens (0) /// compared to pools initialized with default lpTokens (0)