diff --git a/script/DeployMock.sol b/script/DeployMock.sol index 06cd761..48992bb 100644 --- a/script/DeployMock.sol +++ b/script/DeployMock.sol @@ -1,7 +1,7 @@ // SPDX-License-Identifier: UNLICENSED pragma solidity ^0.8.30; -import "../src/Deploy.sol"; +import "../test/Deploy.sol"; import "../src/IPartyPool.sol"; import "../src/PartyPlanner.sol"; import "../src/PartyPool.sol"; diff --git a/src/IPartyPlanner.sol b/src/IPartyPlanner.sol index 653488d..96b6510 100644 --- a/src/IPartyPlanner.sol +++ b/src/IPartyPlanner.sol @@ -122,6 +122,6 @@ interface IPartyPlanner { function mintImpl() external view returns (PartyPoolMintImpl); /// @notice Address of the swap implementation contract used by all pools created by this factory - function swapMintImpl() external view returns (PartyPoolSwapImpl); + function swapImpl() external view returns (PartyPoolSwapImpl); } diff --git a/src/IPartyPool.sol b/src/IPartyPool.sol index 558300a..802c054 100644 --- a/src/IPartyPool.sol +++ b/src/IPartyPool.sol @@ -2,6 +2,7 @@ pragma solidity ^0.8.30; import "../lib/openzeppelin-contracts/contracts/interfaces/IERC3156FlashBorrower.sol"; +import "./IWETH9.sol"; import "./LMSRStabilized.sol"; import {IERC20Metadata} from "../lib/openzeppelin-contracts/contracts/token/ERC20/extensions/IERC20Metadata.sol"; import {IERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/IERC20.sol"; @@ -70,6 +71,9 @@ interface IPartyPool is IERC20Metadata { /// @notice Returns the list of all token addresses in the pool (copy). function allTokens() external view returns (IERC20[] memory); + /// @notice Token contract used for wrapping native currency + function wrapperToken() external view returns (IWETH9); + /// @notice Per-token uint base denominators used to convert uint token amounts <-> internal Q64.64 representation. /// @dev denominators()[i] is the base for tokens[i]. These bases are chosen by deployer and must match token decimals. function denominators() external view returns (uint256[] memory); @@ -136,14 +140,12 @@ interface IPartyPool is IERC20Metadata { /// @param maxAmountIn maximum gross input allowed (inclusive of fee) /// @param limitPrice maximum acceptable marginal price (pass 0 to ignore) /// @return amountIn gross input amount to transfer (includes fee), amountOut output amount user would receive, fee fee amount taken -/* function swapAmounts( uint256 inputTokenIndex, uint256 outputTokenIndex, uint256 maxAmountIn, int128 limitPrice ) external view returns (uint256 amountIn, uint256 amountOut, uint256 fee); -*/ /// @notice Swap input token inputTokenIndex -> token outputTokenIndex. Payer must approve token inputTokenIndex. /// @dev This function transfers the exact gross input (including fee) from payer and sends the computed output to receiver. @@ -164,7 +166,7 @@ interface IPartyPool is IERC20Metadata { uint256 maxAmountIn, int128 limitPrice, uint256 deadline - ) external returns (uint256 amountIn, uint256 amountOut, uint256 fee); + ) external payable returns (uint256 amountIn, uint256 amountOut, uint256 fee); /// @notice Swap up to the price limit; computes max input to reach limit then performs swap. /// @dev If balances prevent fully reaching the limit, the function caps and returns actuals. @@ -183,7 +185,7 @@ interface IPartyPool is IERC20Metadata { uint256 outputTokenIndex, int128 limitPrice, uint256 deadline - ) external returns (uint256 amountInUsed, uint256 amountOut, uint256 fee); + ) external payable returns (uint256 amountInUsed, uint256 amountOut, uint256 fee); /// @notice Single-token mint: deposit a single token, charge swap-LMSR cost, and mint LP. /// @dev swapMint executes as an exact-in planned swap followed by proportional scaling of qInternal. @@ -200,7 +202,7 @@ interface IPartyPool is IERC20Metadata { uint256 inputTokenIndex, uint256 maxAmountIn, uint256 deadline - ) external returns (uint256 lpMinted); + ) external payable returns (uint256 lpMinted); /// @notice Burn LP tokens then swap the redeemed proportional basket into a single asset `inputTokenIndex` and send to receiver. /// @dev The function burns LP tokens (authorization via allowance if needed), sends the single-asset payout and updates LMSR state. diff --git a/src/IWETH9.sol b/src/IWETH9.sol new file mode 100644 index 0000000..93debe5 --- /dev/null +++ b/src/IWETH9.sol @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.30; + +import {IERC20Metadata} from "../lib/openzeppelin-contracts/contracts/token/ERC20/extensions/IERC20Metadata.sol"; + + +interface IWETH9 is IERC20Metadata { + function deposit() external payable; + function withdraw(uint wad) external; +} diff --git a/src/PartyPlanner.sol b/src/PartyPlanner.sol index 8ca4d3e..d77a034 100644 --- a/src/PartyPlanner.sol +++ b/src/PartyPlanner.sol @@ -4,11 +4,12 @@ pragma solidity ^0.8.30; import {IERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/IERC20.sol"; import {SafeERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/utils/SafeERC20.sol"; import {IPartyPlanner} from "./IPartyPlanner.sol"; -import {LMSRStabilized} from "./LMSRStabilized.sol"; import {IPartyPool} from "./IPartyPool.sol"; +import {IWETH9} from "./IWETH9.sol"; +import {LMSRStabilized} from "./LMSRStabilized.sol"; +import {IPartyPoolDeployer} from "./PartyPoolDeployer.sol"; import {PartyPoolMintImpl} from "./PartyPoolMintImpl.sol"; import {PartyPoolSwapImpl} from "./PartyPoolSwapImpl.sol"; -import {IPartyPoolDeployer} from "./PartyPoolDeployer.sol"; /// @title PartyPlanner /// @notice Factory contract for creating and tracking PartyPool instances @@ -21,8 +22,8 @@ contract PartyPlanner is IPartyPlanner { function mintImpl() external view returns (PartyPoolMintImpl) { return MINT_IMPL; } /// @notice Address of the SwapMint implementation contract used by all pools created by this factory - PartyPoolSwapImpl private immutable SWAP_MINT_IMPL; - function swapMintImpl() external view returns (PartyPoolSwapImpl) { return SWAP_MINT_IMPL; } + PartyPoolSwapImpl private immutable SWAP_IMPL; + function swapImpl() external view returns (PartyPoolSwapImpl) { return SWAP_IMPL; } /// @notice Protocol fee share (ppm) applied to fees collected by pools created by this planner uint256 private immutable PROTOCOL_FEE_PPM; @@ -32,6 +33,7 @@ contract PartyPlanner is IPartyPlanner { address private immutable PROTOCOL_FEE_ADDRESS; function protocolFeeAddress() external view returns (address) { return PROTOCOL_FEE_ADDRESS; } + IWETH9 private immutable WRAPPER; IPartyPoolDeployer private immutable NORMAL_POOL_DEPLOYER; IPartyPoolDeployer private immutable BALANCED_PAIR_DEPLOYER; @@ -42,20 +44,22 @@ contract PartyPlanner is IPartyPlanner { mapping(IERC20 => bool) private _tokenSupported; mapping(IERC20 => IPartyPool[]) private _poolsByToken; - /// @param _swapMintImpl address of the SwapMint implementation contract to be used by all pools + /// @param _swapImpl address of the Swap implementation contract to be used by all pools /// @param _mintImpl address of the Mint implementation contract to be used by all pools /// @param _protocolFeePpm protocol fee share (ppm) to be used for pools created by this planner /// @param _protocolFeeAddress recipient address for protocol fees for pools created by this planner (may be address(0)) constructor( - PartyPoolSwapImpl _swapMintImpl, + IWETH9 _wrapper, + PartyPoolSwapImpl _swapImpl, PartyPoolMintImpl _mintImpl, IPartyPoolDeployer _deployer, IPartyPoolDeployer _balancedPairDeployer, uint256 _protocolFeePpm, address _protocolFeeAddress ) { - require(address(_swapMintImpl) != address(0), "Planner: swapMintImpl address cannot be zero"); - SWAP_MINT_IMPL = _swapMintImpl; + WRAPPER = _wrapper; + require(address(_swapImpl) != address(0), "Planner: swapImpl address cannot be zero"); + SWAP_IMPL = _swapImpl; require(address(_mintImpl) != address(0), "Planner: mintImpl address cannot be zero"); MINT_IMPL = _mintImpl; require(address(_deployer) != address(0), "Planner: deployer address cannot be zero"); @@ -107,7 +111,8 @@ contract PartyPlanner is IPartyPlanner { _flashFeePpm, PROTOCOL_FEE_PPM, PROTOCOL_FEE_ADDRESS, - PartyPoolSwapImpl(SWAP_MINT_IMPL), + WRAPPER, + SWAP_IMPL, MINT_IMPL ); diff --git a/src/PartyPool.sol b/src/PartyPool.sol index 37144e1..3192db9 100644 --- a/src/PartyPool.sol +++ b/src/PartyPool.sol @@ -17,6 +17,7 @@ import {Proxy} from "../lib/openzeppelin-contracts/contracts/proxy/Proxy.sol"; import {ReentrancyGuard} from "../lib/openzeppelin-contracts/contracts/utils/ReentrancyGuard.sol"; import {SafeERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/utils/SafeERC20.sol"; import {IERC3156FlashLender} from "../lib/openzeppelin-contracts/contracts/interfaces/IERC3156FlashLender.sol"; +import {IWETH9} from "./IWETH9.sol"; /// @title PartyPool - LMSR-backed multi-asset pool with LP ERC20 token /// @notice A multi-asset liquidity pool backed by the LMSRStabilized pricing model. @@ -36,6 +37,8 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { using LMSRStabilized for LMSRStabilized.State; using SafeERC20 for IERC20; + function wrapperToken() external view returns (IWETH9) { return WRAPPER_TOKEN; } + /// @notice Liquidity parameter κ (Q64.64) used by the LMSR kernel: b = κ * S(q) /// @dev Pool is constructed with a fixed κ. Clients that previously passed tradeFrac/targetSlippage /// should use LMSRStabilized.computeKappaFromSlippage(...) to derive κ and pass it here. @@ -104,9 +107,13 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { uint256 flashFeePpm_, uint256 protocolFeePpm_, // NEW: protocol share of fees (ppm) address protocolFeeAddress_, // NEW: recipient for collected protocol tokens + IWETH9 wrapperToken_, PartyPoolSwapImpl swapImpl_, PartyPoolMintImpl mintImpl_ - ) ERC20External(name_, symbol_) { + ) + PartyPoolBase(wrapperToken_) + ERC20External(name_, symbol_) + { require(tokens_.length > 1, "Pool: need >1 asset"); require(tokens_.length == bases_.length, "Pool: lengths mismatch"); tokens = tokens_; @@ -198,7 +205,6 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { Swaps ---------------------- */ -/* function swapAmounts( uint256 inputTokenIndex, uint256 outputTokenIndex, @@ -208,7 +214,6 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { (uint256 grossIn, uint256 outUint,,,, uint256 feeUint) = _quoteSwapExactIn(inputTokenIndex, outputTokenIndex, maxAmountIn, limitPrice); return (grossIn, outUint, feeUint); } -*/ /// @inheritdoc IPartyPool function swap( @@ -219,7 +224,7 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { uint256 maxAmountIn, int128 limitPrice, uint256 deadline - ) external nonReentrant returns (uint256 amountIn, uint256 amountOut, uint256 fee) { + ) external payable nonReentrant returns (uint256 amountIn, uint256 amountOut, uint256 fee) { require(deadline == 0 || block.timestamp <= deadline, "swap: deadline exceeded"); // Compute amounts using the same path as views @@ -230,15 +235,15 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { IERC20 tokenIn = tokens[inputTokenIndex]; IERC20 tokenOut = tokens[outputTokenIndex]; - // Transfer tokens in - tokenIn.safeTransferFrom(payer, address(this), totalTransferAmount); + // Transfer tokens in via centralized helper + _receiveTokenFrom(payer, tokenIn, totalTransferAmount); // Compute on-chain balances as: onchain = cached + owed (+/- transfer) uint256 balIAfter = cachedUintBalances[inputTokenIndex] + protocolFeesOwed[inputTokenIndex] + totalTransferAmount; uint256 balJAfter = cachedUintBalances[outputTokenIndex] + protocolFeesOwed[outputTokenIndex] - amountOutUint; - // Transfer output to receiver - tokenOut.safeTransfer(receiver, amountOutUint); + // Transfer output to receiver via centralized helper + _sendTokenTo(tokenOut, receiver, amountOutUint); // Accrue protocol share (floor) from the fee on input token if (PROTOCOL_FEE_PPM > 0 && feeUint > 0) { @@ -260,6 +265,8 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { emit Swap(payer, receiver, tokenIn, tokenOut, totalTransferAmount, amountOutUint); + _refund(); + return (totalTransferAmount, amountOutUint, feeUint); } @@ -324,7 +331,7 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { uint256 outputTokenIndex, int128 limitPrice, uint256 deadline - ) external returns (uint256 amountInUsed, uint256 amountOut, uint256 fee) { + ) external payable returns (uint256 amountInUsed, uint256 amountOut, uint256 fee) { bytes memory data = abi.encodeWithSelector( PartyPoolSwapImpl.swapToLimit.selector, payer, @@ -355,7 +362,7 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { uint256 inputTokenIndex, uint256 maxAmountIn, uint256 deadline - ) external returns (uint256 lpMinted) { + ) external payable returns (uint256 lpMinted) { bytes memory data = abi.encodeWithSelector( PartyPoolMintImpl.swapMint.selector, payer, @@ -431,9 +438,9 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { } } - token.safeTransfer(address(receiver), amount); + _sendTokenTo(token, address(receiver), amount); require(receiver.onFlashLoan(msg.sender, address(token), amount, fee, data) == FLASH_CALLBACK_SUCCESS); - token.safeTransferFrom(address(receiver), address(this), amount + fee); + _receiveTokenFrom(address(receiver), token, amount + fee); // Update cached balance for the borrowed token uint256 balAfter = token.balanceOf(address(this)); @@ -458,8 +465,8 @@ contract PartyPool is PartyPoolBase, ERC20External, IPartyPool { uint256 bal = IERC20(tokens[i]).balanceOf(address(this)); require(bal >= owed, "collect: fee > bal"); protocolFeesOwed[i] = 0; - // transfer owed tokens to protocol destination - tokens[i].safeTransfer(dest, owed); + // transfer owed tokens to protocol destination via centralized helper + _sendTokenTo(tokens[i], dest, owed); // update cached to effective onchain minus owed cachedUintBalances[i] = bal - owed; } diff --git a/src/PartyPoolBalancedPair.sol b/src/PartyPoolBalancedPair.sol index a465808..256ab38 100644 --- a/src/PartyPoolBalancedPair.sol +++ b/src/PartyPoolBalancedPair.sol @@ -2,8 +2,10 @@ pragma solidity ^0.8.30; import {IERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/IERC20.sol"; +import {IWETH9} from "./IWETH9.sol"; import {LMSRStabilizedBalancedPair} from "./LMSRStabilizedBalancedPair.sol"; import {PartyPool} from "./PartyPool.sol"; +import {PartyPoolBase} from "./PartyPoolBase.sol"; import {PartyPoolMintImpl} from "./PartyPoolMintImpl.sol"; import {PartyPoolSwapImpl} from "./PartyPoolSwapImpl.sol"; @@ -18,9 +20,10 @@ contract PartyPoolBalancedPair is PartyPool { uint256 flashFeePpm_, uint256 protocolFeePpm_, // NEW: protocol share of fees (ppm) address protocolFeeAddress_, // NEW: recipient for collected protocol tokens + IWETH9 wrapperToken_, PartyPoolSwapImpl swapMintImpl_, PartyPoolMintImpl mintImpl_ - ) PartyPool(name_, symbol_, tokens_, bases_, kappa_, swapFeePpm_, flashFeePpm_, protocolFeePpm_, protocolFeeAddress_, swapMintImpl_, mintImpl_) + ) PartyPool(name_, symbol_, tokens_, bases_, kappa_, swapFeePpm_, flashFeePpm_, protocolFeePpm_, protocolFeeAddress_, wrapperToken_, swapMintImpl_, mintImpl_) {} function _swapAmountsForExactInput(uint256 i, uint256 j, int128 a, int128 limitPrice) internal virtual override view diff --git a/src/PartyPoolBase.sol b/src/PartyPoolBase.sol index 5212707..20ce59b 100644 --- a/src/PartyPoolBase.sol +++ b/src/PartyPoolBase.sol @@ -1,18 +1,27 @@ // SPDX-License-Identifier: UNLICENSED pragma solidity ^0.8.30; +import "./IWETH9.sol"; import {ABDKMath64x64} from "../lib/abdk-libraries-solidity/ABDKMath64x64.sol"; -import {IERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/IERC20.sol"; -import {ReentrancyGuard} from "../lib/openzeppelin-contracts/contracts/utils/ReentrancyGuard.sol"; import {ERC20Internal} from "./ERC20Internal.sol"; +import {IERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/IERC20.sol"; import {LMSRStabilized} from "./LMSRStabilized.sol"; import {PartyPoolHelpers} from "./PartyPoolHelpers.sol"; +import {ReentrancyGuard} from "../lib/openzeppelin-contracts/contracts/utils/ReentrancyGuard.sol"; +import {SafeERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/utils/SafeERC20.sol"; /// @notice Abstract base contract that contains storage and internal helpers only. /// No external/public functions or constructor here — derived implementations own immutables and constructors. abstract contract PartyPoolBase is ERC20Internal, ReentrancyGuard, PartyPoolHelpers { using ABDKMath64x64 for int128; using LMSRStabilized for LMSRStabilized.State; + using SafeERC20 for IERC20; + + IWETH9 internal immutable WRAPPER_TOKEN; + + constructor( IWETH9 wrapper_ ) { + WRAPPER_TOKEN = wrapper_; + } // // Internal state (no immutables here; immutables belong to derived contracts) @@ -78,4 +87,36 @@ abstract contract PartyPoolBase is ERC20Internal, ReentrancyGuard, PartyPoolHelp return floorValue; } + /* ---------------------- + Token transfer helpers (includes autowrap) + ---------------------- */ + + /// @notice Receive tokens from `payer` into the pool (address(this)) using SafeERC20 semantics. + /// @dev Note: this helper does NOT query the on-chain balance after transfer to save gas. + /// Callers should query the balance themselves when they need it (e.g., to detect fee-on-transfer tokens). + function _receiveTokenFrom(address payer, IERC20 token, uint256 amount) internal { + if( token == WRAPPER_TOKEN && msg.value >= amount ) + WRAPPER_TOKEN.deposit{value:amount}(); + else + token.safeTransferFrom(payer, address(this), amount); + } + + /// @notice Send tokens from the pool to `receiver` using SafeERC20 semantics. + /// @dev Note: this helper does NOT query the on-chain balance after transfer to save gas. + /// Callers should query the balance themselves when they need it (e.g., to detect fee-on-transfer tokens). + function _sendTokenTo(IERC20 token, address receiver, uint256 amount) internal { + if( token == WRAPPER_TOKEN ) { + WRAPPER_TOKEN.withdraw(amount); + (bool ok, ) = receiver.call{value: amount}(""); + require(ok); // todo make unwrapping optional + } + else + token.safeTransfer(receiver, amount); + } + + function _refund() internal { + uint256 bal = address(this).balance; + if(bal > 0) + payable(msg.sender).transfer(bal); + } } diff --git a/src/PartyPoolDeployer.sol b/src/PartyPoolDeployer.sol index fbdd3dd..8a79dd9 100644 --- a/src/PartyPoolDeployer.sol +++ b/src/PartyPoolDeployer.sol @@ -20,6 +20,7 @@ interface IPartyPoolDeployer { uint256 flashFeePpm_, uint256 protocolFeePpm_, address protocolFeeAddress_, + IWETH9 wrapper_, PartyPoolSwapImpl swapImpl_, PartyPoolMintImpl mintImpl_ ) external returns (IPartyPool pool); @@ -36,6 +37,7 @@ contract PartyPoolDeployer is IPartyPoolDeployer { uint256 flashFeePpm_, uint256 protocolFeePpm_, address protocolFeeAddress_, + IWETH9 wrapper_, PartyPoolSwapImpl swapImpl_, PartyPoolMintImpl mintImpl_ ) external returns (IPartyPool) { @@ -49,6 +51,7 @@ contract PartyPoolDeployer is IPartyPoolDeployer { flashFeePpm_, protocolFeePpm_, protocolFeeAddress_, + wrapper_, swapImpl_, mintImpl_ ); @@ -66,6 +69,7 @@ contract PartyPoolBalancedPairDeployer is IPartyPoolDeployer { uint256 flashFeePpm_, uint256 protocolFeePpm_, address protocolFeeAddress_, + IWETH9 wrapper_, PartyPoolSwapImpl swapImpl_, PartyPoolMintImpl mintImpl_ ) external returns (IPartyPool) { @@ -79,9 +83,9 @@ contract PartyPoolBalancedPairDeployer is IPartyPoolDeployer { flashFeePpm_, protocolFeePpm_, protocolFeeAddress_, + wrapper_, swapImpl_, mintImpl_ ); } } - diff --git a/src/PartyPoolMintImpl.sol b/src/PartyPoolMintImpl.sol index a4c41e5..861c782 100644 --- a/src/PartyPoolMintImpl.sol +++ b/src/PartyPoolMintImpl.sol @@ -4,8 +4,10 @@ pragma solidity ^0.8.30; import {ABDKMath64x64} from "../lib/abdk-libraries-solidity/ABDKMath64x64.sol"; import {IERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/IERC20.sol"; import {SafeERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/utils/SafeERC20.sol"; +import {ReentrancyGuard} from "../lib/openzeppelin-contracts/contracts/utils/ReentrancyGuard.sol"; import {ERC20Internal} from "./ERC20Internal.sol"; import {IPartyPool} from "./IPartyPool.sol"; +import {IWETH9} from "./IWETH9.sol"; import {LMSRStabilized} from "./LMSRStabilized.sol"; import {PartyPoolBase} from "./PartyPoolBase.sol"; @@ -17,6 +19,7 @@ contract PartyPoolMintImpl is PartyPoolBase { using LMSRStabilized for LMSRStabilized.State; using SafeERC20 for IERC20; + constructor(IWETH9 wrapper_) PartyPoolBase(wrapper_) {} // // Initialization Mint @@ -62,7 +65,7 @@ contract PartyPoolMintImpl is PartyPoolBase { // Regular Mint and Burn // - function mint(address payer, address receiver, uint256 lpTokenAmount, uint256 deadline) external nonReentrant + function mint(address payer, address receiver, uint256 lpTokenAmount, uint256 deadline) external payable nonReentrant returns (uint256 lpMinted) { require(deadline == 0 || block.timestamp <= deadline, "mint: deadline exceeded"); uint256 n = tokens.length; @@ -82,7 +85,7 @@ contract PartyPoolMintImpl is PartyPoolBase { // Transfer in all token amounts for (uint i = 0; i < n; ) { if (depositAmounts[i] > 0) { - tokens[i].safeTransferFrom(payer, address(this), depositAmounts[i]); + _receiveTokenFrom(payer, tokens[i], depositAmounts[i]); } unchecked { i++; } } @@ -124,6 +127,9 @@ contract PartyPoolMintImpl is PartyPoolBase { _mint(receiver, actualLpToMint); emit IPartyPool.Mint(payer, receiver, depositAmounts, actualLpToMint); + + _refund(); + return actualLpToMint; } @@ -151,7 +157,7 @@ contract PartyPoolMintImpl is PartyPoolBase { // Transfer underlying tokens out to receiver according to computed proportions for (uint i = 0; i < n; ) { if (withdrawAmounts[i] > 0) { - tokens[i].safeTransfer(receiver, withdrawAmounts[i]); + _sendTokenTo(tokens[i], receiver, withdrawAmounts[i]); } unchecked { i++; } } @@ -366,8 +372,8 @@ contract PartyPoolMintImpl is PartyPoolBase { uint256 totalTransfer = amountInUint + feeUintActual; require(totalTransfer > 0 && totalTransfer <= maxAmountIn, "swapMint: transfer exceeds max"); - // Transfer tokens from payer (assume standard ERC20 without transfer fees) - tokens[inputTokenIndex].safeTransferFrom(payer, address(this), totalTransfer); + // Transfer tokens from payer (assume standard ERC20 without transfer fees) via helper + _receiveTokenFrom(payer, tokens[inputTokenIndex], totalTransfer); // Accrue protocol share (floor) from the fee on the input token uint256 protoShare = 0; @@ -514,8 +520,8 @@ contract PartyPoolMintImpl is PartyPoolBase { } } - // Transfer the payout to receiver - tokens[inputTokenIndex].safeTransfer(receiver, amountOutUint); + // Transfer the payout to receiver via centralized helper + _sendTokenTo(tokens[inputTokenIndex], receiver, amountOutUint); // Burn LP tokens from payer (authorization via allowance) if (msg.sender != payer) { diff --git a/src/PartyPoolSwapImpl.sol b/src/PartyPoolSwapImpl.sol index 48e831f..de5818c 100644 --- a/src/PartyPoolSwapImpl.sol +++ b/src/PartyPoolSwapImpl.sol @@ -4,9 +4,10 @@ pragma solidity ^0.8.30; import {ABDKMath64x64} from "../lib/abdk-libraries-solidity/ABDKMath64x64.sol"; import {IERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/IERC20.sol"; import {SafeERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/utils/SafeERC20.sol"; +import {IPartyPool} from "./IPartyPool.sol"; +import {IWETH9} from "./IWETH9.sol"; import {LMSRStabilized} from "./LMSRStabilized.sol"; import {PartyPoolBase} from "./PartyPoolBase.sol"; -import {IPartyPool} from "./IPartyPool.sol"; /// @title PartyPoolSwapMintImpl - Implementation contract for swapMint and burnSwap functions /// @notice This contract contains the swapMint and burnSwap implementation that will be called via delegatecall @@ -16,6 +17,8 @@ contract PartyPoolSwapImpl is PartyPoolBase { using LMSRStabilized for LMSRStabilized.State; using SafeERC20 for IERC20; + constructor(IWETH9 wrapper_) PartyPoolBase(wrapper_) {} + function swapToLimitAmounts( uint256 inputTokenIndex, uint256 outputTokenIndex, @@ -55,7 +58,7 @@ contract PartyPoolSwapImpl is PartyPoolBase { uint256 deadline, uint256 swapFeePpm, uint256 protocolFeePpm - ) external returns (uint256 amountInUsed, uint256 amountOut, uint256 fee) { + ) external payable returns (uint256 amountInUsed, uint256 amountOut, uint256 fee) { uint256 n = tokens.length; require(inputTokenIndex < n && outputTokenIndex < n, "swapToLimit: idx"); require(limitPrice > int128(0), "swapToLimit: limit <= 0"); @@ -70,12 +73,12 @@ contract PartyPoolSwapImpl is PartyPoolBase { _quoteSwapToLimit(inputTokenIndex, outputTokenIndex, limitPrice, swapFeePpm); // Transfer the exact amount needed from payer and require exact receipt (revert on fee-on-transfer) - tokens[inputTokenIndex].safeTransferFrom(payer, address(this), totalTransferAmount); + _receiveTokenFrom(payer, tokens[inputTokenIndex], totalTransferAmount); uint256 balIAfter = IERC20(tokens[inputTokenIndex]).balanceOf(address(this)); require(balIAfter == prevBalI + totalTransferAmount, "swapToLimit: non-standard tokenIn"); // Transfer output to receiver and verify exact decrease - tokens[outputTokenIndex].safeTransfer(receiver, amountOutUint); + _sendTokenTo(tokens[outputTokenIndex], receiver, amountOutUint); uint256 balJAfter = IERC20(tokens[outputTokenIndex]).balanceOf(address(this)); require(balJAfter == prevBalJ - amountOutUint, "swapToLimit: non-standard tokenOut"); @@ -100,6 +103,8 @@ contract PartyPoolSwapImpl is PartyPoolBase { // Maintain original event semantics (logs input without fee) emit IPartyPool.Swap(payer, receiver, tokens[inputTokenIndex], tokens[outputTokenIndex], amountInUsedUint, amountOutUint); + _refund(); + return (amountInUsedUint, amountOutUint, feeUint); } diff --git a/src/Deploy.sol b/test/Deploy.sol similarity index 61% rename from src/Deploy.sol rename to test/Deploy.sol index 9ccb274..1504dde 100644 --- a/src/Deploy.sol +++ b/test/Deploy.sol @@ -2,22 +2,26 @@ pragma solidity ^0.8.30; import {IERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/IERC20.sol"; -import {PartyPlanner} from "./PartyPlanner.sol"; -import {PartyPool} from "./PartyPool.sol"; -import {PartyPoolBalancedPair} from "./PartyPoolBalancedPair.sol"; -import {PartyPoolDeployer, PartyPoolBalancedPairDeployer} from "./PartyPoolDeployer.sol"; -import {PartyPoolMintImpl} from "./PartyPoolMintImpl.sol"; -import {PartyPoolSwapImpl} from "./PartyPoolSwapImpl.sol"; -import {PartyPoolViewer} from "./PartyPoolViewer.sol"; +import {IWETH9} from "../src/IWETH9.sol"; +import {PartyPlanner} from "../src/PartyPlanner.sol"; +import {PartyPool} from "../src/PartyPool.sol"; +import {PartyPoolBalancedPair} from "../src/PartyPoolBalancedPair.sol"; +import {PartyPoolDeployer, PartyPoolBalancedPairDeployer} from "../src/PartyPoolDeployer.sol"; +import {PartyPoolMintImpl} from "../src/PartyPoolMintImpl.sol"; +import {PartyPoolSwapImpl} from "../src/PartyPoolSwapImpl.sol"; +import {PartyPoolViewer} from "../src/PartyPoolViewer.sol"; +import {WETH9} from "./WETH9.sol"; library Deploy { address internal constant PROTOCOL_FEE_RECEIVER = 0x70997970C51812dc3A010C7d01b50e0d17dc79C8; // dev account #1 uint256 internal constant PROTOCOL_FEE_PPM = 100_000; // 10% function newPartyPlanner() internal returns (PartyPlanner) { + IWETH9 wrapper = new WETH9(); return new PartyPlanner( - new PartyPoolSwapImpl(), - new PartyPoolMintImpl(), + wrapper, + new PartyPoolSwapImpl(wrapper), + new PartyPoolMintImpl(wrapper), new PartyPoolDeployer(), new PartyPoolBalancedPairDeployer(), PROTOCOL_FEE_PPM, @@ -35,6 +39,7 @@ library Deploy { uint256 _flashFeePpm, bool _stable ) internal returns (PartyPool) { + IWETH9 wrapper = new WETH9(); return _stable && tokens_.length == 2 ? new PartyPoolBalancedPair( name_, @@ -46,8 +51,9 @@ library Deploy { _flashFeePpm, PROTOCOL_FEE_PPM, PROTOCOL_FEE_RECEIVER, - new PartyPoolSwapImpl(), - new PartyPoolMintImpl() + wrapper, + new PartyPoolSwapImpl(wrapper), + new PartyPoolMintImpl(wrapper) ) : new PartyPool( name_, @@ -59,13 +65,15 @@ library Deploy { _flashFeePpm, PROTOCOL_FEE_PPM, PROTOCOL_FEE_RECEIVER, - new PartyPoolSwapImpl(), - new PartyPoolMintImpl() + wrapper, + new PartyPoolSwapImpl(wrapper), + new PartyPoolMintImpl(wrapper) ); } function newViewer() internal returns (PartyPoolViewer) { - return new PartyPoolViewer(new PartyPoolSwapImpl(), new PartyPoolMintImpl()); + IWETH9 wrapper = new WETH9(); + return new PartyPoolViewer(new PartyPoolSwapImpl(wrapper), new PartyPoolMintImpl(wrapper)); } } diff --git a/test/GasTest.sol b/test/GasTest.sol index 55168c3..20fa5fc 100644 --- a/test/GasTest.sol +++ b/test/GasTest.sol @@ -9,7 +9,7 @@ import "../src/LMSRStabilized.sol"; import "../src/PartyPool.sol"; import "../src/PartyPlanner.sol"; import "../lib/openzeppelin-contracts/contracts/interfaces/IERC3156FlashBorrower.sol"; -import {Deploy} from "../src/Deploy.sol"; +import {Deploy} from "./Deploy.sol"; /// @notice Test contract that implements the flash callback for testing flash loans contract FlashBorrower is IERC3156FlashBorrower { diff --git a/test/PartyPlanner.t.sol b/test/PartyPlanner.t.sol index a3ffb96..6a0e85a 100644 --- a/test/PartyPlanner.t.sol +++ b/test/PartyPlanner.t.sol @@ -9,7 +9,7 @@ import {StdUtils} from "../lib/forge-std/src/StdUtils.sol"; import {Test} from "../lib/forge-std/src/Test.sol"; import {ERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/ERC20.sol"; import {IERC20} from "../lib/openzeppelin-contracts/contracts/token/ERC20/IERC20.sol"; -import {Deploy} from "../src/Deploy.sol"; +import {Deploy} from "./Deploy.sol"; import {IPartyPool} from "../src/IPartyPool.sol"; import {LMSRStabilized} from "../src/LMSRStabilized.sol"; import {PartyPlanner} from "../src/PartyPlanner.sol"; diff --git a/test/PartyPool.t.sol b/test/PartyPool.t.sol index 73eab90..a97c9b7 100644 --- a/test/PartyPool.t.sol +++ b/test/PartyPool.t.sol @@ -11,7 +11,7 @@ import "../src/PartyPool.sol"; // Import the flash callback interface import "../lib/openzeppelin-contracts/contracts/interfaces/IERC3156FlashBorrower.sol"; import {PartyPlanner} from "../src/PartyPlanner.sol"; -import {Deploy} from "../src/Deploy.sol"; +import {Deploy} from "./Deploy.sol"; import {PartyPoolViewer} from "../src/PartyPoolViewer.sol"; /// @notice Test contract that implements the flash callback for testing flash loans diff --git a/test/WETH9.sol b/test/WETH9.sol new file mode 100644 index 0000000..8748a70 --- /dev/null +++ b/test/WETH9.sol @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: GPL-3.0-or-later +pragma solidity ^0.8.30; + +import {IWETH9} from "../src/IWETH9.sol"; + +contract WETH9 is IWETH9 { + string public name = "Wrapped Ether"; + string public symbol = "WETH"; + uint8 public decimals = 18; + + event Deposit(address indexed dst, uint256 wad); + event Withdrawal(address indexed src, uint256 wad); + + mapping(address => uint256) public balanceOf; + mapping(address => mapping(address => uint256)) public allowance; + + receive() external payable { + deposit(); + } + + function deposit() public payable { + balanceOf[msg.sender] += msg.value; + emit Deposit(msg.sender, msg.value); + } + + function withdraw(uint256 wad) public { + require(balanceOf[msg.sender] >= wad, ""); + balanceOf[msg.sender] -= wad; + payable(msg.sender).transfer(wad); + emit Withdrawal(msg.sender, wad); + } + + function totalSupply() public view returns (uint256) { + return address(this).balance; + } + + function approve(address guy, uint256 wad) public returns (bool) { + allowance[msg.sender][guy] = wad; + emit Approval(msg.sender, guy, wad); + return true; + } + + function transfer(address dst, uint256 wad) public returns (bool) { + return transferFrom(msg.sender, dst, wad); + } + + function transferFrom( + address src, + address dst, + uint256 wad + ) public returns (bool) { + require(balanceOf[src] >= wad, ""); + + if ( + src != msg.sender && allowance[src][msg.sender] != type(uint256).max + ) { + require(allowance[src][msg.sender] >= wad, ""); + allowance[src][msg.sender] -= wad; + } + + balanceOf[src] -= wad; + balanceOf[dst] += wad; + + emit Transfer(src, dst, wad); + + return true; + } +}