flashLoan rewritten as ERC-3156

This commit is contained in:
tim
2025-10-07 14:12:27 -04:00
parent 677ce4886c
commit ef039aa57e
5 changed files with 199 additions and 505 deletions

View File

@@ -8,104 +8,74 @@ import "@openzeppelin/contracts/token/ERC20/ERC20.sol";
import "../src/LMSRStabilized.sol";
import "../src/PartyPool.sol";
import "../src/PartyPlanner.sol";
import "../src/IPartyFlashCallback.sol";
import "../lib/openzeppelin-contracts/contracts/interfaces/IERC3156FlashBorrower.sol";
import {Deploy} from "../src/Deploy.sol";
/// @notice Test contract that implements the flash callback for testing flash loans
contract FlashBorrower is IPartyFlashCallback {
contract FlashBorrower is IERC3156FlashBorrower {
enum Action {
NORMAL, // Normal repayment
REPAY_NONE, // Don't repay anything
REPAY_PARTIAL, // Repay less than required
REPAY_NO_FEE, // Repay only the principal without fee
REPAY_EXACT, // Repay exactly the required amount
REPAY_EXTRA // Repay more than required (donation)
REPAY_EXACT // Repay exactly the required amount
}
Action public action;
address public pool;
address public recipient;
address[] public tokens;
address public payer;
constructor(address _pool, IERC20[] memory _tokens) {
constructor(address _pool) {
pool = _pool;
tokens = new address[](_tokens.length);
for (uint i = 0; i < _tokens.length; i++) {
tokens[i] = address(_tokens[i]);
}
}
function setAction(Action _action, address _recipient) external {
function setAction(Action _action, address _payer) external {
action = _action;
recipient = _recipient;
payer = _payer;
}
function flash(uint256[] memory amounts) external {
PartyPool(pool).flash(recipient, amounts, "");
function flash(address token, uint256 amount) external {
PartyPool(pool).flashLoan(IERC3156FlashBorrower(address(this)), token, amount, "");
}
function partyFlashCallback(
uint256[] memory loanAmounts,
uint256[] memory repaymentAmounts,
function onFlashLoan(
address initiator,
address token,
uint256 amount,
uint256 fee,
bytes calldata /* data */
) external override {
) external override returns (bytes32) {
require(msg.sender == pool, "Callback not called by pool");
if (action == Action.NORMAL || action == Action.REPAY_EXTRA) {
// Normal or extra repayment - transfer required amounts back to pool
for (uint256 i = 0; i < loanAmounts.length; i++) {
if (loanAmounts[i] > 0) {
uint256 repaymentAmount = repaymentAmounts[i];
if (action == Action.NORMAL) {
// Normal repayment
// We received 'amount' from the pool, need to pay back amount + fee
uint256 repaymentAmount = amount + fee;
// For REPAY_EXTRA, add 1 to each repayment
if (action == Action.REPAY_EXTRA) {
repaymentAmount += 1;
}
// Transfer the fee from payer to this contract
// (we already have the principal 'amount' from the flash loan)
TestERC20(token).transferFrom(payer, address(this), fee);
// Transfer from recipient back to pool
TestERC20(tokens[i]).transferFrom(
recipient,
pool,
repaymentAmount
);
}
}
// Approve pool to pull back the full repayment
TestERC20(token).approve(pool, repaymentAmount);
} else if (action == Action.REPAY_PARTIAL) {
// Repay half of the required amounts
for (uint256 i = 0; i < loanAmounts.length; i++) {
if (loanAmounts[i] > 0) {
uint256 partialRepayment = repaymentAmounts[i] / 2;
TestERC20(tokens[i]).transferFrom(
recipient,
pool,
partialRepayment
);
}
}
// Repay half of the required amount
uint256 partialRepayment = (amount + fee) / 2;
TestERC20(token).approve(pool, partialRepayment);
} else if (action == Action.REPAY_NO_FEE) {
// Repay only the principal without fee
for (uint256 i = 0; i < loanAmounts.length; i++) {
if (loanAmounts[i] > 0) {
TestERC20(tokens[i]).transferFrom(
recipient,
pool,
loanAmounts[i]
);
}
}
// Repay only the principal without fee (we already have it from the loan)
TestERC20(token).approve(pool, amount);
} else if (action == Action.REPAY_EXACT) {
// Repay exactly what was required
for (uint256 i = 0; i < loanAmounts.length; i++) {
if (loanAmounts[i] > 0) {
TestERC20(tokens[i]).transferFrom(
recipient,
pool,
repaymentAmounts[i]
);
}
}
uint256 repaymentAmount = amount + fee;
// Transfer the fee from payer (we have the principal from the loan)
TestERC20(token).transferFrom(payer, address(this), fee);
// Approve pool to pull back the full repayment
TestERC20(token).approve(pool, repaymentAmount);
}
// For REPAY_NONE, do nothing (don't repay)
// For REPAY_NONE, do nothing (don't approve repayment)
return keccak256("ERC3156FlashBorrower.onFlashLoan");
}
}
@@ -247,13 +217,11 @@ contract GasTest is Test {
/// @notice Setup a flash borrower for testing
function setupFlashBorrower() internal returns (FlashBorrower borrower) {
// Get token addresses from the 2-token pool
IERC20[] memory tokenAddresses = pool2.allTokens();
// Deploy the borrower contract
borrower = new FlashBorrower(address(pool2), tokenAddresses);
borrower = new FlashBorrower(address(pool2));
// Mint tokens to alice to be used for repayments and approve borrower
IERC20[] memory tokenAddresses = pool2.allTokens();
vm.startPrank(alice);
for (uint256 i = 0; i < tokenAddresses.length; i++) {
TestERC20(address(tokenAddresses[i])).mint(alice, INIT_BAL * 2);
@@ -441,33 +409,14 @@ contract GasTest is Test {
// Configure borrower
borrower.setAction(FlashBorrower.Action.NORMAL, alice);
// Create loan request for single token (get array size from pool)
// Get first token from pool
IERC20[] memory poolTokens = pool2.allTokens();
uint256[] memory amounts = new uint256[](poolTokens.length);
amounts[0] = 1000;
address token = address(poolTokens[0]);
uint256 amount = 1000;
// Execute flash loan 10 times to measure gas
for (uint256 i = 0; i < 10; i++) {
borrower.flash(amounts);
}
}
/// @notice Gas measurement: flash with multiple tokens
function testFlashGasMultipleTokens() public {
FlashBorrower borrower = setupFlashBorrower();
// Configure borrower
borrower.setAction(FlashBorrower.Action.NORMAL, alice);
// Create loan request for multiple tokens (get array size from pool)
IERC20[] memory poolTokens = pool2.allTokens();
uint256[] memory amounts = new uint256[](poolTokens.length);
amounts[0] = 1000;
amounts[1] = 2000;
// Execute flash loan 10 times to measure gas
for (uint256 i = 0; i < 10; i++) {
borrower.flash(amounts);
borrower.flash(token, amount);
}
}
}