commit 5fb2b17b2e0f99c4598e03ef5c28af3a28b1bfb8 Author: tim Date: Mon Sep 15 14:21:56 2025 -0400 dxod repo init diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..51f60db --- /dev/null +++ b/.gitignore @@ -0,0 +1,13 @@ +# Compiler files +cache/ +out/ + +docs/ +log/ +.env +.idea + +# Ignores development broadcast logs +!/broadcast +/broadcast/*/31337/ +/broadcast/**/dry-run/ diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..6819c28 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,9 @@ +[submodule "lib/forge-std"] + path = lib/forge-std + url = https://github.com/foundry-rs/forge-std +[submodule "lib/abdk-libraries-solidity"] + path = lib/abdk-libraries-solidity + url = https://github.com/abdk-consulting/abdk-libraries-solidity +[submodule "lib/openzeppelin-contracts"] + path = lib/openzeppelin-contracts + url = https://github.com/OpenZeppelin/openzeppelin-contracts diff --git a/bin/mock b/bin/mock new file mode 100755 index 0000000..a9db798 --- /dev/null +++ b/bin/mock @@ -0,0 +1,49 @@ +#!/bin/bash + +# Function to cleanup processes +cleanup() { + kill $ANVIL_PID 2>/dev/null +} + +# Set up trap to handle script exit +trap cleanup EXIT + +# Create log directory if it doesn't exist +mkdir -p log + +# Run anvil in background and redirect output to log file +anvil | tee log/anvil.txt & +ANVIL_PID=$! + +# Function to check if string exists in file +check_string() { + grep -q "$1" "$2" + return $? +} + +# Wait for anvil to start (max 30 seconds) +echo "Waiting for anvil to start..." +counter=0 +while ! check_string "Listening on" "log/anvil.txt"; do + sleep 1 + counter=$((counter + 1)) + if [ $counter -ge 5 ]; then + echo "Timeout waiting for anvil to start" + exit 1 + fi +done + +# Extract bytecode using jq +BYTECODE=$(jq -r '.bytecode.object' out/PartyPool.sol/PartyPool.json) +if [ $? -ne 0 ] || [ -z "$BYTECODE" ]; then + echo "Failed to extract bytecode from PartyPool.json" + exit 1 +fi + +export BYTECODE +forge script DeployMock --broadcast + +echo "Press Ctrl+C to exit..." +while true; do + sleep 1 +done diff --git a/doc/stablecoins.csv b/doc/stablecoins.csv new file mode 100644 index 0000000..5810b39 --- /dev/null +++ b/doc/stablecoins.csv @@ -0,0 +1,47 @@ +Token Name,Token Symbol,Price,Volume (24h) +Tether,USDT,1.00,124178611175 +USDC,USDC,0.9996,15873101191 +First Digital USD,FDUSD,0.9981,6001738498 +World Liberty Financial USD,USD1,0.9994,443670921 +Ethena USDe,USDe,1.00,203389879 +Dai,DAI,0.9998,107786336 +Ripple USD,RLUSD,0.9997,95900714 +PayPal USD,PYUSD,0.9994,84926338 +Falcon USD,USDf,1.00,75128301 +EURC,EURC,1.17,50655783 +TrueUSD,TUSD,0.9976,41243610 +StabiR USD,USDR,0.9980,37757919 +EUR CoinVertible,EURCV,1.17,36001785 +StraitsX USD,XUSD,1.00,30768737 +Global Dollar,USDG,1.00,26395021 +AUSD,AUSD,0.9999,23425385 +Quantoz EURQ,EURQ,1.17,15986967 +Quantoz USDQ,USDQ,0.9993,11283404 +Elixir deUSD,DEUSD,0.9992,11801718 +Gemini Dollar,GUSD,0.9997,10831936 +Eurite,EURi,1.17,8463732 +USDD,USDD,0.9998,6457661 +BUSD,BUSD,0.9998,4320760 +Ondo US Dollar Yield,USDY,1.09,3876681 +JUSDJ,JUSDJ,1.32,3243537 +Pax Dollar,USDP,0.9998,3010075 +StabiR Euro,EURR,1.16,2581109 +Steem Dollars,SBD,0.8492,2517238 +Bucket Protocol BUCK Stablecoin,BUCK,0.9990,2375173 +Hyper USD,USDHL,0.9995,2227470 +USDP Stablecoin,USDP,0.9998,2126519 +AllUnity EUR,EURAU,1.17,1919081 +Celo Dollar,CUSD,1.00,1809631 +Worldwide USD,WUSD,1.00,1306544 +Usual USD,USD0,0.9981,3897930 +GHO,GHO,0.9997,675336 +Legacy Frax Dollar,FRAX,0.9979,513159 +STASIS EURO,EURS,1.16,23449869 +Noble Dollar,USDN,0.9996, +Frax USD,FRXUSD,0.9986, +lisUSD,lisUSD,0.9990,34118 +USDB,USDB,0.9987,952802 +Venus BUSD,vBUSD,0.02229, +MNEE,MNEE,0.9974,86346 +Anchored Coins,AEUR,1.10,26168 +Lift Dollar,USDL,0.9987,76398 diff --git a/foundry.toml b/foundry.toml new file mode 100644 index 0000000..bad1645 --- /dev/null +++ b/foundry.toml @@ -0,0 +1,17 @@ +[profile.default] +src = "src" +out = "out" +libs = ["lib"] +remappings = [ + '@openzeppelin/=lib/openzeppelin-contracts/', + '@abdk/=lib/abdk-libraries-solidity/', +] +optimizer=true +optimizer_runs=999999999 +viaIR=true +gas_reports = ['PartyPool'] + +[lint] +exclude_lints=['mixed-case-variable', 'unaliased-plain-import', ] + +# See more config options https://github.com/foundry-rs/foundry/blob/master/crates/config/README.md#all-options diff --git a/lib/abdk-libraries-solidity b/lib/abdk-libraries-solidity new file mode 160000 index 0000000..5e1e7c1 --- /dev/null +++ b/lib/abdk-libraries-solidity @@ -0,0 +1 @@ +Subproject commit 5e1e7c11b35f8313d3f7ce11c1b86320d7c0b554 diff --git a/lib/forge-std b/lib/forge-std new file mode 160000 index 0000000..8bbcf6e --- /dev/null +++ b/lib/forge-std @@ -0,0 +1 @@ +Subproject commit 8bbcf6e3f8f62f419e5429a0bd89331c85c37824 diff --git a/lib/openzeppelin-contracts b/lib/openzeppelin-contracts new file mode 160000 index 0000000..c64a1ed --- /dev/null +++ b/lib/openzeppelin-contracts @@ -0,0 +1 @@ +Subproject commit c64a1edb67b6e3f4a15cca8909c9482ad33a02b0 diff --git a/research/LMSRComparisonAnalysis.py b/research/LMSRComparisonAnalysis.py new file mode 100644 index 0000000..6d7159d --- /dev/null +++ b/research/LMSRComparisonAnalysis.py @@ -0,0 +1,354 @@ +# lmsr_vs_cp_sim.py +# Requires: Python 3.9+, numpy, matplotlib +# Optional: seaborn (for prettier plots) + +import math +import numpy as np +import matplotlib.pyplot as plt +try: + import seaborn as sns + sns.set_context("talk") + sns.set_style("whitegrid") +except Exception: + pass + + +# --------------------------- +# Core AMM primitives +# --------------------------- + +def cp_price(X, Y): + # Instantaneous marginal price p = dy/dx for constant product + return Y / X + +def cp_trade_y_out(X, Y, qx_in, fee=0.0): + """ + Swap x-in (amount qx_in) for y-out on a constant-product pool with fee. + Fee taken on input (Uniswap-style): only (1-fee)*qx_in enters the invariant. + Returns (dy_out, X_new, Y_new, fee_value). + """ + assert X > 0 and Y > 0 and qx_in >= 0 and 0 <= fee < 1 + k = X * Y + effective_in = qx_in * (1 - fee) + X_new = X + effective_in + Y_new = k / X_new + dy_out = Y - Y_new + fee_value = qx_in * fee # in x-units + return dy_out, X_new, Y_new, fee_value + +def cp_avg_y_per_x(X, Y, qx_in, fee=0.0): + if qx_in == 0: + return cp_price(X, Y) + dy, *_ = cp_trade_y_out(X, Y, qx_in, fee) + return dy / qx_in + +def lmsr_calibrate_b(X, Y, mode="thin_side", factor=0.5): + """ + Calibrate LMSR 'b' to match CP local log-price stiffness at the current state. + - thin_side: use min(X, Y) * factor where factor≈0.5 matches ds/dq at that side + - x_side: use X * factor + - y_side: use Y * factor (useful if measuring q in y-space) + For our use (q measured in x), 'x_side' or 'thin_side' with factor=0.5 are typical. + """ + if mode == "thin_side": + return max(1e-12, factor * min(X, Y)) + elif mode == "x_side": + return max(1e-12, factor * X) + elif mode == "y_side": + return max(1e-12, factor * Y) + else: + raise ValueError("mode must be one of ['thin_side','x_side','y_side']") + +def lmsr_y_out(qx_in, p0, b, fee=0.0): + """ + Symmetric 2-asset LMSR approximation in x-measure: + - We model log-price s moving linearly in q_x: ds/dq_x = 1/b + - Then price path p(q) = p0 * exp(q/b) + - Cumulative y received for x-in q is integral of p(q) dq: y(q) = p0 * b * (exp(q/b) - 1) + Fee applied to input reduces effective q: q_eff = q * (1 - fee) + Returns dy_out (in y-units) and fee_value (in x-units) + """ + assert qx_in >= 0 and b > 0 and p0 > 0 and 0 <= fee < 1 + q_eff = qx_in * (1 - fee) + dy = p0 * b * (math.exp(q_eff / b) - 1.0) + fee_value = qx_in * fee + return dy, fee_value + +def lmsr_avg_y_per_x(qx_in, p0, b, fee=0.0): + if qx_in == 0: + return p0 + dy, _ = lmsr_y_out(qx_in, p0, b, fee) + return dy / qx_in + + +# --------------------------- +# Static comparison utilities +# --------------------------- + +def static_compare(X, Y, fee_cp=0.0005, fee_lmsr=0.0005, b=None, b_mode="x_side", b_factor=0.5, + q_grid=None): + """ + Compute average execution (y per x), penalties vs spot, and welfare gaps over a grid of trade sizes. + Returns dict with arrays. + """ + p0 = cp_price(X, Y) + if b is None: + b = lmsr_calibrate_b(X, Y, mode=b_mode, factor=b_factor) + + if q_grid is None: + # span from tiny to a meaningful fraction of X + q_grid = np.geomspace(1e-9 * X, 0.3 * X, 60) + + avg_cp = np.array([cp_avg_y_per_x(X, Y, q, fee_cp) for q in q_grid]) + avg_lm = np.array([lmsr_avg_y_per_x(q, p0, b, fee_lmsr) for q in q_grid]) + + # Slippage penalty vs spot p0 (in y-per-x) + pen_cp = p0 - avg_cp + pen_lm = p0 - avg_lm + + # Taker welfare difference: LMSR minus CP (positive means LMSR better for taker) + welfare_gap = avg_lm - avg_cp + + return { + "q_grid": q_grid, + "avg_cp": avg_cp, + "avg_lm": avg_lm, + "pen_cp": pen_cp, + "pen_lm": pen_lm, + "welfare_gap": welfare_gap, + "p0": p0, + "b": b + } + + +# --------------------------- +# Imbalance sweep +# --------------------------- + +def sweep_imbalance(X, rho_grid, fee_cp=0.0005, fee_lmsr=0.0005, q_frac=0.01, + b_mode="thin_side", b_factor=0.5): + """ + For a fixed X (x-reserve), sweep imbalance rho = Y/X and measure welfare advantage at a chosen trade size q = q_frac * X. + Returns arrays over rho. + """ + q = q_frac * X + wgap = [] + pen_cp_list, pen_lm_list, p0_list, b_list = [], [], [], [] + for rho in rho_grid: + Y = rho * X + res = static_compare(X, Y, fee_cp, fee_lmsr, b=None, b_mode=b_mode, b_factor=b_factor, q_grid=np.array([q])) + wgap.append(res["welfare_gap"][0]) + pen_cp_list.append(res["pen_cp"][0]) + pen_lm_list.append(res["pen_lm"][0]) + p0_list.append(res["p0"]) + b_list.append(res["b"]) + return { + "rho_grid": rho_grid, + "q": q, + "welfare_gap": np.array(wgap), + "pen_cp": np.array(pen_cp_list), + "pen_lm": np.array(pen_lm_list), + "p0": np.array(p0_list), + "b": np.array(b_list) + } + + +# --------------------------- +# Sequential Monte Carlo simulation +# --------------------------- + +def sample_trade_sizes(n, X, dist="lognormal", mean_frac=0.005, std_frac=0.01, seed=None): + """ + Return non-negative trade sizes q in x-units. + - lognormal: parameters chosen so mean ≈ mean_frac*X + - uniform: U(0, 2*mean_frac*X) + """ + rng = np.random.default_rng(seed) + if dist == "lognormal": + mean = mean_frac * X + std = std_frac * X + # Map mean/std to lognormal parameters + # mean = exp(mu + sigma^2/2), var = (exp(sigma^2)-1)exp(2mu+sigma^2) + # Let CV = std/mean => sigma^2 = ln(1+CV^2), mu = ln(mean) - sigma^2/2 + cv = std / max(1e-12, mean) + sigma2 = math.log(1 + cv * cv) + sigma = math.sqrt(sigma2) + mu = math.log(max(1e-12, mean)) - sigma2 / 2 + q = rng.lognormal(mean=mu, sigma=sigma, size=n) + return q + elif dist == "uniform": + return rng.uniform(0, 2 * mean_frac * X, size=n) + else: + raise ValueError("Unknown dist") + + +def sequential_sim(X0, Y0, n_trades=2000, direction_bias=0.6, + fee_cp=0.0005, fee_lmsr=0.0005, + b_mode="thin_side", b_factor=0.5, + dist="lognormal", mean_frac=0.005, std_frac=0.01, seed=42): + """ + Run a sequential simulation where each trade either buys y with x (prob=direction_bias) + or buys x with y (prob=1-direction_bias). We compare CP vs LMSR per-trade and accumulate: + - taker surplus difference (y-per-x times q, converted to y-units) + - fee revenue for LPs (x- or y-denominated; we also track value at spot) + - pool state evolution (for CP only; LMSR is state-less under this approximation) + Note: We approximate LMSR as state-less with ds/dq = 1/b, b recalibrated each step to the thin side. + """ + rng = np.random.default_rng(seed) + X_cp, Y_cp = X0, Y0 + p_history = [] + b_history = [] + + taker_y_advantage = 0.0 + lp_fee_x_cp = 0.0 + lp_fee_x_lm = 0.0 + + # Track realized slippage stats at fixed snapshot prices for comparability + for t in range(n_trades): + p_cp = cp_price(X_cp, Y_cp) + p_history.append(p_cp) + + # Calibrate LMSR b to thin side each step + b = lmsr_calibrate_b(X_cp, Y_cp, mode=b_mode, factor=b_factor) + b_history.append(b) + + # Decide direction: True => use x to buy y; False => use y to buy x + buy_y = rng.random() < direction_bias + + if buy_y: + # Draw trade size in x-units + qx = float(sample_trade_sizes(1, X_cp, dist, mean_frac, std_frac, seed=rng.integers(1e9))[0]) + + # CP execution + dy_cp, X_cp_new, Y_cp_new, fee_x_cp = cp_trade_y_out(X_cp, Y_cp, qx, fee_cp) + + # LMSR execution (approximate with current spot p_cp) + dy_lm, fee_x_lm = lmsr_y_out(qx, p_cp, b, fee_lmsr) + + # Update CP state + X_cp, Y_cp = X_cp_new, Y_cp_new + + # Accumulate + taker_y_advantage += (dy_lm - dy_cp) + lp_fee_x_cp += fee_x_cp + lp_fee_x_lm += fee_x_lm + + else: + # Buy x with y. Mirror the model by symmetry: + # We'll convert problem by swapping roles of (x,y) and using same formulas. + # Draw trade size in y-units proportional to Y + qy = float(sample_trade_sizes(1, Y_cp, dist, mean_frac, std_frac, seed=rng.integers(1e9))[0]) + + # CP: y-in, x-out + # Symmetric function by swapping X<->Y and interpreting result + dy_dummy, Y_new, X_new, fee_y_cp = cp_trade_y_out(Y_cp, X_cp, qy, fee_cp) # returns x-out in "dy_dummy" + dx_cp = dy_dummy # interpret as x-out + X_cp, Y_cp = X_new, Y_new + + # LMSR: state-less approx using ds/dq_y = 1/b_y; we map via price and symmetry. + # For buy-x with y-in, average x per y uses 1/p along exponential path in y-space. + # For simplicity, we mirror by computing y-per-x with LMSR at price p, then invert locally: + # Expected x received ≈ qy / p * (e^{(qy_eff/b_y)/?} - 1)/((qy)/?) ... too detailed. + # Instead, use symmetry by swapping labels and reusing the function with p_inv = 1/p. + p_inv = 1.0 / p_cp + b_y = lmsr_calibrate_b(X_cp, Y_cp, mode="y_side", factor=b_factor) + dx_lm, fee_y_lm = lmsr_y_out(qy, p_inv, b_y, fee_lmsr) # gives x-out + + # Taker advantage in x-units; convert to y-units for aggregation using pre-trade spot + taker_y_advantage += (dx_lm - dx_cp) * p_cp # multiply by p to convert x to y at current spot + lp_fee_x_cp += fee_y_cp * p_inv # convert y-fee to x using p_inv + lp_fee_x_lm += fee_y_lm * p_inv + + return { + "taker_y_advantage": taker_y_advantage, + "lp_fee_x_cp": lp_fee_x_cp, + "lp_fee_x_lm": lp_fee_x_lm, + "p_history": np.array(p_history), + "b_history": np.array(b_history), + "final_state_cp": (X_cp, Y_cp) + } + + +# --------------------------- +# Plotting helpers +# --------------------------- + +def plot_static(res): + q = res["q_grid"] + p0 = res["p0"] + plt.figure(figsize=(10, 6)) + plt.plot(q, res["avg_cp"], label="CP avg y-per-x") + plt.plot(q, res["avg_lm"], label="LMSR avg y-per-x") + plt.axhline(p0, color="gray", linestyle="--", alpha=0.6, label="Spot p0") + plt.xscale("log") + plt.title("Average execution (y per x) vs trade size") + plt.xlabel("q_x (trade size)") + plt.ylabel("y per x") + plt.legend() + plt.tight_layout() + + plt.figure(figsize=(10, 6)) + plt.plot(q, res["welfare_gap"] / res["p0"] * 1e4) + plt.xscale("log") + plt.axhline(0, color="gray", linestyle="--") + plt.title("Taker welfare advantage LMSR−CP (in bp of price)") + plt.xlabel("q_x (trade size)") + plt.ylabel("bp") + plt.tight_layout() + +def plot_imbalance(sweep): + rho = sweep["rho_grid"] + plt.figure(figsize=(10, 6)) + plt.plot(rho, sweep["welfare_gap"] / sweep["p0"] * 1e4) + plt.axhline(0, color="gray", linestyle="--") + plt.title(f"LMSR−CP taker advantage vs imbalance (q={sweep['q']:.4g}·X)") + plt.xlabel("rho = Y/X (imbalance)") + plt.ylabel("bp of price") + plt.tight_layout() + + +# --------------------------- +# Demo main +# --------------------------- + +def demo(): + # Baseline pool + X, Y = 10_000.0, 10_000.0 + fee_cp = 0.0005 # 5 bp + fee_lmsr = 0.0005 # 5 bp + b_mode = "thin_side" + b_factor = 0.5 + + # Static comparison across trade sizes + res = static_compare( + X, Y, fee_cp=fee_cp, fee_lmsr=fee_lmsr, + b=None, b_mode=b_mode, b_factor=b_factor + ) + print(f"Spot p0={res['p0']:.6f}, calibrated b={res['b']:.6f}") + plot_static(res) + + # Imbalance sweep + rho_grid = np.linspace(0.2, 5.0, 60) # from thin-y to thin-x + sw = sweep_imbalance( + X, rho_grid, fee_cp=fee_cp, fee_lmsr=fee_lmsr, + q_frac=0.01, b_mode=b_mode, b_factor=b_factor + ) + plot_imbalance(sw) + + # Sequential simulation + sim = sequential_sim( + X0=X, Y0=Y, n_trades=2000, direction_bias=0.6, + fee_cp=fee_cp, fee_lmsr=fee_lmsr, + b_mode=b_mode, b_factor=b_factor, + dist="lognormal", mean_frac=0.005, std_frac=0.01, seed=7 + ) + adv_bp_equiv = sim["taker_y_advantage"] / (X + Y / res["p0"]) * 1e4 + print(f"Sequential sim taker Y-advantage (absolute): {sim['taker_y_advantage']:.6f} y-units") + print(f"LP fee (CP) in x-units: {sim['lp_fee_x_cp']:.6f}") + print(f"LP fee (LMSR) in x-units: {sim['lp_fee_x_lm']:.6f}") + + plt.show() + + +if __name__ == "__main__": + demo() diff --git a/script/DeployMock.sol b/script/DeployMock.sol new file mode 100644 index 0000000..53c7c31 --- /dev/null +++ b/script/DeployMock.sol @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.30; + +import "forge-std/Script.sol"; +import "forge-std/console2.sol"; +import "@abdk/ABDKMath64x64.sol"; +import "../test/MockERC20.sol"; +import "../src/IPartyPool.sol"; +import "../src/PartyPool.sol"; + +contract DeployMock is Script { + + // private key 0x4bbbf85ce3377467afe5d46f804f221813b2bb87f24d81f60f1fcdbf7cbf4356 + address constant devAccount7 = 0x14dC79964da2C08b23698B3D3cc7Ca32193d9955; + + function run() public { + vm.startBroadcast(deployer); + + // create mock tokens + usxd = new MockERC20('Joke Currency', 'USXD', 6); + fusd = new MockERC20('Fake USD', 'FUSD', 6); + dive = new MockERC20('DAI Virtually Equal', 'DIVE', 18); + + string memory name = 'Mock Pool'; + string memory symbol = 'MP'; + address[] memory tokens = new address[](3); + tokens[0] = address(usxd); + tokens[1] = address(fusd); + tokens[2] = address(dive); + uint256[] memory _bases = new uint256[](3); + _bases[0] = 6; + _bases[1] = 6; + _bases[2] = 18; + int128 _tradeFrac = ABDKMath64x64.divu(1, 10); + int128 _targetSlippage = ABDKMath64x64.divu(1,10000); + uint256 _feePpm = 100; + + IPartyPool pool = new PartyPool(); + bytes memory args = abi.encode(name, symbol, tokens, _bases, _tradeFrac, _targetSlippage, _feePpm); + bytes memory deployCode = abi.encodePacked(bytecode,args); + vm.etch(pool, deployCode); + + console2.log('PartyPool', pool); + + // initial mint + mintAll(pool, 10_000); + IPartyPool(pool).mint(deployer, deployer, 0, 0); + + console2.log('USXD', address(usxd)); + console2.log('FUSD', address(fusd)); + console2.log('DIVE', address(dive)); + + // give tokens to dev7 + mintAll(devAccount7, 1_000_000); + + vm.stopBroadcast(); + } + + address constant deployer = address(0x472358699872673459876); // anything + + MockERC20 private usxd; + MockERC20 private fusd; + MockERC20 private dive; + + function mintAll(address who, uint256 amount) internal { + usxd.mint(who, amount * 1e6); + fusd.mint(who, amount * 1e6); + dive.mint(who, amount * 1e18); + } + +} diff --git a/src/IPartyFlashCallback.sol b/src/IPartyFlashCallback.sol new file mode 100644 index 0000000..bb7cb4d --- /dev/null +++ b/src/IPartyFlashCallback.sol @@ -0,0 +1,6 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.30; + +interface IPartyFlashCallback { + function partyFlashCallback(uint256[] memory loanAmounts, uint256[] memory repaymentAmounts, bytes calldata data) external; +} diff --git a/src/IPartyPool.sol b/src/IPartyPool.sol new file mode 100644 index 0000000..3948158 --- /dev/null +++ b/src/IPartyPool.sol @@ -0,0 +1,149 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.30; + +import "@openzeppelin/contracts/token/ERC20/extensions/IERC20Metadata.sol"; + +/// @title PartyPool - LMSR-backed multi-asset pool with LP ERC20 token +/// @notice Uses LMSRStabilized library; stores per-token uint bases to convert to/from 64.64 fixed point. +/// - Caches qInternal[] (int128 64.64) and cachedUintBalances[] to minimize balanceOf() calls. +/// - swap and swapToLimit mimic core lib; mint/burn call updateForProportionalChange() and manage LP tokens. +interface IPartyPool is IERC20Metadata { + // All int128's are ABDKMath64x64 format + + // Events + + event Mint(address payer, address indexed receiver, uint256[] amounts, uint256 lpMinted); + + event Burn(address payer, address indexed receiver, uint256[] amounts, uint256 lpBurned); + + event Swap( + address payer, + address indexed receiver, + address indexed tokenIn, + address indexed tokenOut, + uint256 amountIn, + uint256 amountOut + ); + + /// @notice Emitted when a single-token swapMint is executed. + /// Records payer/receiver, input token index, gross transfer (net+fee), net input and fee taken. + event SwapMint( + address indexed payer, + address indexed receiver, + uint256 indexed inputTokenIndex, + uint256 grossTransfer, // total tokens transferred (net + fee) + uint256 netInput, // net input credited to swaps (after fee) + uint256 feeTaken // fee taken (ceil) + ); + + /// @notice Emitted when a burnSwap is executed. + /// Records payer/receiver, target token index and the uint payout sent to the receiver. + event BurnSwap( + address indexed payer, + address indexed receiver, + uint256 indexed targetTokenIndex, + uint256 payoutUint + ); + + + // Immutable pool configuration (public getters) + function tokens(uint256) external view returns (address); // get single token + function numTokens() external view returns (uint256); + function allTokens() external view returns (address[] memory); + function tradeFrac() external view returns (int128); // ABDK 64x64 + function targetSlippage() external view returns (int128); // ABDK 64x64 + function swapFeePpm() external view returns (uint256); + function tokenAddressToIndexPlusOne(address) external view returns (uint); + + // Initialization / Mint / Burn (LP token managed) + + /// @notice Calculate the proportional deposit amounts required for a given LP token amount + /// @param lpTokenAmount The amount of LP tokens desired + /// @return depositAmounts Array of token amounts to deposit (rounded up) + function computeMintAmounts(uint256 lpTokenAmount) external view returns (uint256[] memory depositAmounts); + + /// @notice Proportional mint (or initial supply if first call). + /// For initial supply: assumes tokens have already been transferred to the pool + /// For subsequent mints: payer must approve tokens beforehand, receiver gets the LP tokens + /// @param payer address that provides the input tokens (ignored for initial deposit) + /// @param receiver address that receives the LP tokens + /// @param lpTokenAmount desired amount of LP tokens to mint (ignored for initial deposit) + /// @param deadline timestamp after which the transaction will revert. Pass 0 to ignore. + function mint(address payer, address receiver, uint256 lpTokenAmount, uint256 deadline) external; + + /// @notice Calculate the proportional withdrawal amounts for a given LP token amount + /// @param lpTokenAmount The amount of LP tokens to burn + /// @return withdrawAmounts Array of token amounts to withdraw (rounded down) + function computeBurnAmounts(uint256 lpTokenAmount) external view returns (uint256[] memory withdrawAmounts); + + /// @notice Burn LP tokens and withdraw the proportional basket to receiver. + /// Payer must own the LP tokens; withdraw amounts are computed from current proportions. + /// @param payer address that provides the LP tokens to burn + /// @param receiver address that receives the withdrawn tokens + /// @param lpAmount amount of LP tokens to burn (proportional withdrawal) + /// @param deadline timestamp after which the transaction will revert. Pass 0 to ignore. + function burn(address payer, address receiver, uint256 lpAmount, uint256 deadline) external; + + + // Swaps + function swap( + address payer, + address receiver, + uint256 i, + uint256 j, + uint256 maxAmountIn, + int128 limitPrice, + uint256 deadline + ) external returns (uint256 amountIn, uint256 amountOut); + + function swapToLimit( + address payer, + address receiver, + uint256 i, + uint256 j, + int128 limitPrice, + uint256 deadline + ) external returns (uint256 amountInUsed, uint256 amountOut); + + /// @notice Single-token mint: deposit a single token, charge swap-LMSR cost, and mint LP. + /// @param payer who transfers the input token + /// @param receiver who receives the minted LP tokens + /// @param i index of the input token + /// @param maxAmountIn maximum uint token input (inclusive of fee) + /// @param deadline optional deadline + /// @return lpMinted actual LP minted (uint) + function swapMint( + address payer, + address receiver, + uint256 i, + uint256 maxAmountIn, + uint256 deadline + ) external returns (uint256 lpMinted); + + /// @notice Burn LP tokens then swap the redeemed proportional basket into a single asset `i` and send to receiver. + /// @param payer who burns LP tokens + /// @param receiver who receives the single asset + /// @param lpAmount amount of LP tokens to burn + /// @param i index of target asset to receive + /// @param deadline optional deadline + /// @return amountOutUint uint amount of asset i sent to receiver + function burnSwap( + address payer, + address receiver, + uint256 lpAmount, + uint256 i, + uint256 deadline + ) external returns (uint256 amountOutUint); + + /// @notice Receive token0 and/or token1 and pay it back, plus a fee, in the callback + /// @dev The caller of this method receives a callback in the form of IPartyFlashCallback#partyFlashCallback + /// @param recipient The address which will receive the token amounts + /// @param amounts The amount of each token to send + /// @param data Any data to be passed through to the callback + function flash( + address recipient, + uint256[] memory amounts, + bytes calldata data + ) external; + +} diff --git a/src/LMSRStabilized.sol b/src/LMSRStabilized.sol new file mode 100644 index 0000000..e7d75f5 --- /dev/null +++ b/src/LMSRStabilized.sol @@ -0,0 +1,1138 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.30; + +import "forge-std/console2.sol"; +import "@abdk/ABDKMath64x64.sol"; + +/// @notice Stabilized LMSR library with incremental exp(z) caching for gas efficiency. +/// - Stores b (64.64), M (shift), Z = sum exp(z_i), z[i] = (q_i / b) - M +/// - Caches e[i] = exp(z[i]) so we avoid recomputing exp() for every asset on each trade. +/// - Provides closed-form ΔC on deposit, amount-out for asset->asset, +/// and incremental applyDeposit/applyWithdraw that update e[i] and Z in O(1). +library LMSRStabilized { + using ABDKMath64x64 for int128; + + struct State { + uint256 nAssets; + int128 kappa; // liquidity parameter κ (64.64 fixed point) + int128[] qInternal; // cached internal balances in 64.64 fixed-point format + } + + /* -------------- + Initialization + -------------- */ + + /// @notice Initialize the stabilized state from internal balances qInternal (int128[]) + /// qInternal must be normalized to 64.64 fixed-point format. + function init( + State storage s, + int128[] memory initialQInternal, + int128 tradeFrac, + int128 targetSlippage + ) internal { + s.nAssets = initialQInternal.length; + + // Initialize qInternal cache + if (s.qInternal.length != initialQInternal.length) { + s.qInternal = new int128[](initialQInternal.length); + } + for (uint i = 0; i < initialQInternal.length; ) { + s.qInternal[i] = initialQInternal[i]; + unchecked { i++; } + } + + int128 total = _computeSizeMetric(s.qInternal); + console2.log("total (internal 64.64)"); + console2.logInt(total); + require(total > int128(0), "LMSR: total zero"); + + console2.log("LMSR.init: start"); + console2.log("nAssets", s.nAssets); + console2.log("qInternal.length", s.qInternal.length); + + // Compute kappa from slippage parameters + setKappaFromSlippage(s, tradeFrac, targetSlippage); + console2.log("kappa (64x64)"); + console2.logInt(s.kappa); + require(s.kappa > int128(0), "LMSR: kappa>0"); + + console2.log("LMSR.init: done"); + } + + /* -------------------- + View helpers + -------------------- */ + + /// @notice Cost C(q) = b * (M + ln(Z)) + function cost(State storage s) internal view returns (int128) { + int128 b = _computeB(s); + (int128 M, int128 Z) = _computeMAndZ(b, s.qInternal); + int128 lnZ = _ln(Z); + int128 inner = M.add(lnZ); + int128 c = b.mul(inner); + return c; + } + + + /* --------- + Swapping + --------- */ + + /// @notice Closed-form asset-i -> asset-j amountOut in 64.64 fixed-point format (fee-free kernel) + /// Uses the closed-form two-asset LMSR formula (no fees in kernel): + /// y = b * ln(1 + r0 * (1 - exp(-a / b))) + /// where r0 = e_i / e_j. + /// + /// This variant accepts an additional `limitPrice` (64.64) which represents the + /// maximum acceptable marginal price (p_i / p_j). If the marginal price would + /// exceed `limitPrice` before the requested `a` is fully consumed, the input + /// `a` is truncated to the value that makes the marginal price equal `limitPrice`. + /// + /// NOTE: Kernel is fee-free; fees should be handled by the wrapper/token layer. + /// + /// @param i Index of input asset + /// @param j Index of output asset + /// @param a Amount of input asset (in int128 format, 64.64 fixed-point) + /// @param limitPrice Maximum acceptable price ratio (64.64). If <= current price, this call reverts. + /// @return amountIn Actual amount of input asset used (may be less than `a` if limited by price) + /// @return amountOut Amount of output asset j in 64.64 fixed-point format + function swapAmountsForExactInput( + State storage s, + uint256 i, + uint256 j, + int128 a, + int128 limitPrice + ) internal view returns (int128 amountIn, int128 amountOut) { + require(i < s.nAssets && j < s.nAssets, "LMSR: idx"); + + // Initialize amountIn to full amount (will be adjusted if limit price is hit) + amountIn = a; + + // Compute b and ensure positivity before deriving invB + int128 b = _computeB(s); + require(b > int128(0), "LMSR: b<=0"); + + // Precompute reciprocal of b to avoid repeated divisions + int128 invB = ABDKMath64x64.div(ONE, b); + + // Guard: output asset must have non-zero effective weight to avoid degenerate/div-by-zero-like conditions + require(s.qInternal[j] > int128(0), "LMSR: e_j==0"); + + // Compute r0 = exp((q_i - q_j) / b) directly using invB + int128 r0 = _exp(s.qInternal[i].sub(s.qInternal[j]).mul(invB)); + require(r0 > int128(0), "LMSR: r0<=0"); // equivalent to e_j > 0 check + + // If a positive limitPrice is given, determine whether the full `a` would + // push the marginal price p_i/p_j beyond the limit; if so, truncate `a`. + // Marginal price ratio evolves as r(t) = r0 * exp(t/b) (since e_i multiplies by exp(t/b)) + if (limitPrice > int128(0)) { + console2.log("\n=== LimitPrice Logic Debug ==="); + console2.log("Received limitPrice (64x64):"); + console2.logInt(limitPrice); + + console2.log("Current price ratio r0 (e_i/e_j, 64x64):"); + console2.logInt(r0); + + // r0 must be positive; if r0 == 0 then no risk of exceeding limit by increasing r. + require(r0 >= int128(0), "LMSR: r0<0"); + if (r0 == int128(0)) { + console2.log("r0 == 0 (input asset has zero weight), no limit truncation needed"); + } else { + // If limitPrice <= current price, we revert (caller must choose a limit > current price to allow any fill) + if (limitPrice <= r0) { + console2.log("Limit price is <= current price: reverting"); + revert("LMSR: limitPrice <= current price"); + } + + // Compute a_limit directly from ln(limit / r0): a_limit = b * ln(limit / r0) + int128 ratioLimitOverR0 = limitPrice.div(r0); + console2.log("limitPrice/r0 (64x64):"); + console2.logInt(ratioLimitOverR0); + require(ratioLimitOverR0 > int128(0), "LMSR: ratio<=0"); + + int128 aLimitOverB = _ln(ratioLimitOverR0); // > 0 + console2.log("ln(limitPrice/r0) (64x64):"); + console2.logInt(aLimitOverB); + + // aLimit = b * aLimitOverB + int128 aLimit64 = b.mul(aLimitOverB); + console2.log("aLimit in 64x64 format:"); + console2.logInt(aLimit64); + + // If computed aLimit is less than the requested a, use the truncated value. + if (aLimit64 < a) { + console2.log("TRUNCATING: a reduced from 64.64 value"); + console2.logInt(a); + console2.log("to 64.64 value"); + console2.logInt(aLimit64); + amountIn = aLimit64; // Store the truncated input amount + a = aLimit64; // Use truncated amount for calculations + } else { + console2.log("Not truncating: aLimit64 >= a"); + } + } + } + + // compute a/b safely and guard against very large arguments to exp() + int128 aOverB = a.mul(invB); + // Protect exp from enormous inputs (consistent with recenter thresholds) + require(aOverB <= EXP_LIMIT, "LMSR: a/b too large (would overflow exp)"); + + console2.log("\n=== AmountOut Calculation Debug ==="); + console2.log("Input amount (64.64):"); + console2.logInt(a); + console2.log("a/b (64x64):"); + console2.logInt(aOverB); + + // Use the closed-form fee-free formula: + // y = b * ln(1 + r0 * (1 - exp(-a/b))) + console2.log("r0_for_calc (e_i/e_j):"); + console2.logInt(r0); + + int128 expNeg = _exp(aOverB.neg()); // exp(-a/b) + console2.log("exp(-a/b):"); + console2.logInt(expNeg); + + int128 oneMinusExpNeg = ONE.sub(expNeg); + console2.log("1 - exp(-a/b):"); + console2.logInt(oneMinusExpNeg); + + int128 inner = ONE.add(r0.mul(oneMinusExpNeg)); + console2.log("inner = 1 + r0 * (1 - exp(-a/b)):"); + console2.logInt(inner); + + // If inner <= 0 then cap output to the current balance q_j (cannot withdraw more than q_j) + if (inner <= int128(0)) { + console2.log("WARNING: inner <= 0, capping output to balance q_j"); + int128 qj64 = s.qInternal[j]; + console2.log("Capped output (64.64):"); + console2.logInt(qj64); + return (amountIn, qj64); + } + + int128 lnInner = _ln(inner); + console2.log("ln(inner):"); + console2.logInt(lnInner); + + int128 b_lnInner = b.mul(lnInner); + console2.log("b*ln(inner):"); + console2.logInt(b_lnInner); + + amountOut = b_lnInner; + console2.log("amountOut = b*ln(inner):"); + console2.logInt(amountOut); + + console2.log("amountOut (final 64.64 amount):"); + console2.logInt(amountOut); + + // Safety check + if (amountOut <= 0) { + console2.log("WARNING: x64 <= 0, returning 0"); + return (0, 0); + } + } + + + /// @notice Maximum input/output pair possible when swapping from asset i to asset j + /// given a maximum acceptable price ratio (p_i/p_j). + /// Returns the input amount that would drive the marginal price to the limit (amountIn) + /// and the corresponding output amount (amountOut). If the output would exceed the + /// j-balance, amountOut is capped and amountIn is solved for the capped output. + /// + /// @param i Index of input asset + /// @param j Index of output asset + /// @param limitPrice Maximum acceptable price ratio (64.64) + /// @return amountIn Maximum input amount in 64.64 fixed-point format that reaches the price limit + /// @return amountOut Corresponding maximum output amount in 64.64 fixed-point format + function swapAmountsForPriceLimit( + State storage s, + uint256 i, + uint256 j, + int128 limitPrice + ) internal view returns (int128 amountIn, int128 amountOut) { + require(i < s.nAssets && j < s.nAssets, "LMSR: idx"); + require(limitPrice > int128(0), "LMSR: limitPrice <= 0"); + + // Compute b and ensure positivity before deriving invB + int128 b = _computeB(s); + require(b > int128(0), "LMSR: b<=0"); + + // Precompute reciprocal of b to avoid repeated divisions + int128 invB = ABDKMath64x64.div(ONE, b); + + // Guard: output asset must have non-zero effective weight to avoid degenerate/div-by-zero-like conditions + require(s.qInternal[j] > int128(0), "LMSR: e_j==0"); + + // Compute r0 = exp((q_i - q_j) / b) directly using invB + int128 r0 = _exp(s.qInternal[i].sub(s.qInternal[j]).mul(invB)); + + console2.log("\n=== Max Input/Output Calculation ==="); + console2.log("Limit price (64x64):"); + console2.logInt(limitPrice); + console2.log("Current price ratio r0 (e_i/e_j, 64x64):"); + console2.logInt(r0); + + // Mirror swapAmountsForExactInput behavior: treat invalid r0 as an error condition. + // Revert if r0 is non-positive (no finite trade under a price limit). + require(r0 > int128(0), "LMSR: r0<=0"); + + // If current price already exceeds or equals limit, revert the same way swapAmountsForExactInput does. + if (r0 >= limitPrice) { + console2.log("Limit price is <= current price: reverting"); + revert("LMSR: limitPrice <= current price"); + } + + // Calculate the price change factor: limitPrice/r0 + int128 priceChangeFactor = limitPrice.div(r0); + console2.log("Price change factor (limitPrice/r0):"); + console2.logInt(priceChangeFactor); + + // ln(priceChangeFactor) gives us the maximum allowed delta in the exponent + int128 maxDeltaExponent = _ln(priceChangeFactor); + console2.log("Max delta exponent ln(priceChangeFactor):"); + console2.logInt(maxDeltaExponent); + + // Maximum input capable of reaching the price limit: + // x_max = b * ln(limitPrice / r0) + int128 amountInMax = b.mul(maxDeltaExponent); + console2.log("Max input to reach limit (64.64):"); + console2.logInt(amountInMax); + + // The maximum output y corresponding to that input: + // y = b * ln(1 + (e_i/e_j) * (1 - exp(-x_max/b))) + int128 expTerm = ONE.sub(_exp(maxDeltaExponent.neg())); + console2.log("1 - exp(-maxDeltaExponent):"); + console2.logInt(expTerm); + + int128 innerTerm = r0.mul(expTerm); + console2.log("e_i/e_j * expTerm:"); + console2.logInt(innerTerm); + + int128 lnTerm = _ln(ONE.add(innerTerm)); + console2.log("ln(1 + innerTerm):"); + console2.logInt(lnTerm); + + int128 maxOutput = b.mul(lnTerm); + console2.log("Max output (b * lnTerm):"); + console2.logInt(maxOutput); + + // Current balance of asset j (in 64.64) + int128 qj64 = s.qInternal[j]; + console2.log("Current j balance (64.64):"); + console2.logInt(qj64); + + // Initialize outputs to the computed maxima + amountIn = amountInMax; + amountOut = maxOutput; + + // If the calculated maximum output exceeds the balance, cap output and solve for input. + if (maxOutput > qj64) { + console2.log("Max output exceeds balance, capping to balance"); + amountOut = qj64; + + // Solve inverse relation for input given capped output: + // Given y = amountOut, let E = exp(y/b). Then + // 1 - exp(-a/b) = (E - 1) / r0 + // exp(-a/b) = 1 - (E - 1) / r0 = (r0 + 1 - E) / r0 + // a = -b * ln( (r0 + 1 - E) / r0 ) = b * ln( r0 / (r0 + 1 - E) ) + int128 E = _exp(amountOut.mul(invB)); // exp(y/b) + int128 rhs = r0.add(ONE).sub(E); // r0 + 1 - E + console2.log("E = exp(y/b):"); + console2.logInt(E); + console2.log("rhs = r0 + 1 - E:"); + console2.logInt(rhs); + + // If rhs <= 0 due to numerical issues, fall back to amountInMax + if (rhs <= int128(0)) { + console2.log("Numerical issue solving inverse; using amountInMax as fallback"); + amountIn = amountInMax; + } else { + amountIn = b.mul(_ln(r0.div(rhs))); + console2.log("Computed input required for capped output (64.64):"); + console2.logInt(amountIn); + } + } + + return (amountIn, amountOut); + } + + /// @notice Compute LP-size increase when minting from a single-token input using bisection only. + /// @dev Solve for α >= 0 such that: + /// a = α*q_i + sum_{j != i} x_j(α) + /// where x_j(α) is the input to swap i->j that yields y_j = α*q_j and + /// x_j = b * ln( r0_j / (r0_j + 1 - exp(y_j / b)) ), r0_j = exp((q_i - q_j)/b). + /// Bisection is used (no Newton) to keep implementation compact and gas-friendly. + function swapAmountsForMint( + State storage s, + uint256 i, + int128 a + ) internal view returns (int128 amountIn, int128 amountOut) { + require(i < s.nAssets, "LMSR: idx"); + require(a > int128(0), "LMSR: amount <= 0"); + + int128 b = _computeB(s); + require(b > int128(0), "LMSR: b<=0"); + int128 invB = ABDKMath64x64.div(ONE, b); + int128 S = _computeSizeMetric(s.qInternal); + + uint256 n = s.nAssets; + + // Precompute r0_j = exp((q_i - q_j) / b) for all j to avoid recomputing during search. + int128[] memory r0 = new int128[](n); + for (uint256 j = 0; j < n; ) { + r0[j] = _exp(s.qInternal[i].sub(s.qInternal[j]).mul(invB)); + unchecked { j++; } + } + + // convergence epsilon in Q64.64 (~1e-6) + int128 eps = ABDKMath64x64.divu(1, 1_000_000); + + // Helper inline: compute required input for given alpha (returns very large on failure) + // We'll inline the body where needed to avoid nested captures. + + // Find upper bound by doubling (start from reasonable guess a/S) + int128 low = int128(0); + int128 high; + if (S > int128(0)) { + high = ABDKMath64x64.div(a, S); // initial guess α ~ a / S + if (high < ONE) { + high = ONE; // at least 1.0 + } + } else { + // degenerate; treat as zero outcome + revert('LMSR: swapMint degenerate'); + } + + // Safety cap for alpha (prevent runaway doubling) + int128 alphaCap = ABDKMath64x64.fromUInt(1 << 20); + + // Doubling phase to ensure aRequired(high) >= a (or hit cap) + for (uint iter = 0; iter < 64; ) { + // compute aRequired at current high + int128 alpha = high; + int128 sumX = int128(0); + bool fail = false; + + // loop j != i + for (uint256 j = 0; j < n; ) { + if (j != i) { + int128 yj = alpha.mul(s.qInternal[j]); // target output y_j = alpha * q_j + if (yj > int128(0)) { + int128 expArg = yj.mul(invB); + // Guard exp arg + if (expArg > EXP_LIMIT) { fail = true; break; } + int128 E = _exp(expArg); // exp(yj / b) + int128 rhs = r0[j].add(ONE).sub(E); // r0 + 1 - E + if (rhs <= int128(0)) { fail = true; break; } + int128 numer = r0[j].div(rhs); + if (numer <= int128(0)) { fail = true; break; } + int128 xj = b.mul(_ln(numer)); + if (xj < int128(0)) { fail = true; break; } + sumX = sumX.add(xj); + } + } + unchecked { j++; } + } + + int128 aReq = fail ? int128(type(int128).max) : alpha.mul(s.qInternal[i]).add(sumX); + + if (aReq >= a || high >= alphaCap) { + break; + } + + // double high + high = high.mul(ABDKMath64x64.fromUInt(2)); + if (high > alphaCap) { high = alphaCap; } + unchecked { iter++; } + } + + // Bisection in [low, high] + int128 foundAlpha = low; + for (uint iter = 0; iter < 64; ) { + int128 mid = ABDKMath64x64.div(low.add(high), ABDKMath64x64.fromUInt(2)); + int128 alpha = mid; + int128 sumX = int128(0); + bool fail = false; + + for (uint256 j = 0; j < n; ) { + if (j != i) { + int128 yj = alpha.mul(s.qInternal[j]); + if (yj > int128(0)) { + int128 expArg = yj.mul(invB); + if (expArg > EXP_LIMIT) { fail = true; break; } + int128 E = _exp(expArg); + int128 rhs = r0[j].add(ONE).sub(E); + if (rhs <= int128(0)) { fail = true; break; } + int128 numer = r0[j].div(rhs); + if (numer <= int128(0)) { fail = true; break; } + int128 xj = b.mul(_ln(numer)); + if (xj < int128(0)) { fail = true; break; } + sumX = sumX.add(xj); + } + } + unchecked { j++; } + } + + int128 aReq = fail ? int128(type(int128).max) : alpha.mul(s.qInternal[i]).add(sumX); + + if (aReq > a) { + // mid requires more input than provided -> decrease alpha + high = mid; + } else { + // mid requires <= provided input -> alpha can be at least mid + low = mid; + } + + // convergence + if (high.sub(low) <= eps) { + foundAlpha = low; + break; + } + + // final iteration fallback + if (iter == 63) { + foundAlpha = low; + } + + unchecked { iter++; } + } + + // compute actual required input at foundAlpha (may be slightly <= a) + int128 alphaFinal = foundAlpha; + int128 sumXFinal = int128(0); + bool failFinal = false; + for (uint256 j = 0; j < n; ) { + if (j != i) { + int128 yj = alphaFinal.mul(s.qInternal[j]); + if (yj > int128(0)) { + int128 expArg = yj.mul(invB); + if (expArg > EXP_LIMIT) { failFinal = true; break; } + int128 E = _exp(expArg); + int128 rhs = r0[j].add(ONE).sub(E); + if (rhs <= int128(0)) { failFinal = true; break; } + int128 numer = r0[j].div(rhs); + if (numer <= int128(0)) { failFinal = true; break; } + int128 xj = b.mul(_ln(numer)); + if (xj < int128(0)) { failFinal = true; break; } + sumXFinal = sumXFinal.add(xj); + } + } + unchecked { j++; } + } + + if (failFinal) { + // Numerical failure -> signal zero outcome conservatively + return (int128(0), int128(0)); + } + + int128 aRequired = alphaFinal.mul(s.qInternal[i]).add(sumXFinal); + + // amountIn is actual consumed input (may be <= provided a) + amountIn = aRequired; + // amountOut is alpha * S (LP-equivalent increase) + amountOut = alphaFinal.mul(S); + + // If values are numerically zero (no meaningful trade) revert to avoid zero-mint edge case. + if (amountOut <= int128(0) || amountIn <= int128(0)) { + revert("LMSR: zero output"); + } + + return (amountIn, amountOut); + } + + /// @notice Compute single-asset payout when burning a proportional share alpha of the pool. + /// @dev Simulate q_after = (1 - alpha) * q, return the amount of asset `i` the burner + /// would receive after swapping each other asset's withdrawn portion into `i`. + /// For each j != i: + /// - wrapper holds a_j = alpha * q_j + /// - swap j->i with closed-form exact-input formula using the current q_local + /// - cap output to q_local[i] when necessary (solve inverse for input used) + /// Treat any per-asset rhs<=0 as "this asset contributes zero" (do not revert). + /// Revert only if the final single-asset payout is zero. + function swapAmountsForBurn( + State storage s, + uint256 i, + int128 alpha + ) internal view returns (int128 amountOut, int128 amountIn) { + require(i < s.nAssets, "LMSR: idx"); + require(alpha > int128(0) && alpha <= ONE, "LMSR: alpha"); + + int128 b = _computeB(s); + require(b > int128(0), "LMSR: b<=0"); + int128 invB = ABDKMath64x64.div(ONE, b); + + uint256 n = s.nAssets; + + // Size metric and burned size (amountIn returned) + int128 S = _computeSizeMetric(s.qInternal); + amountIn = alpha.mul(S); // total size-metric redeemed + + // Build q_local := q_after_burn = (1 - alpha) * q + int128[] memory qLocal = new int128[](n); + for (uint256 j = 0; j < n; ) { + qLocal[j] = s.qInternal[j].mul(ONE.sub(alpha)); + unchecked { j++; } + } + + // Start totalOut with direct portion of asset i redeemed + int128 totalOut = alpha.mul(s.qInternal[i]); + + // Track whether any non-zero contribution was produced + bool anyNonZero = (totalOut > int128(0)); + + // For each asset j != i, swap the withdrawn a_j := alpha * q_j into i + for (uint256 j = 0; j < n; ) { + if (j != i) { + int128 aj = alpha.mul(s.qInternal[j]); // wrapper-held withdrawn amount of j + if (aj > int128(0)) { + // expArg = aj / b + int128 expArg = aj.mul(invB); + + // Guard exp argument magnitude; if too large treat contribution as zero + if (expArg > EXP_LIMIT) { + // skip this asset's contribution (numerically unsafe) + unchecked { j++; } + continue; + } + + // r0_j = exp((q_local[j] - q_local[i]) / b) + int128 r0_j = _exp(qLocal[j].sub(qLocal[i]).mul(invB)); + + // closed-form amountOut candidate: + // y = b * ln(1 + r0 * (1 - exp(-a/b))) + int128 expNeg = _exp(expArg.neg()); // exp(-a/b) + int128 inner = ONE.add(r0_j.mul(ONE.sub(expNeg))); + + if (inner <= int128(0)) { + // treat as zero contribution from this asset + unchecked { j++; } + continue; + } + + int128 y = b.mul(_ln(inner)); + + // If computed y would exceed the available pool balance q_local[i], cap to q_local[i] + if (y > qLocal[i]) { + // Cap output to qLocal[i]; solve inverse for input used (amountInUsed) + // E = exp(y_cap / b) where y_cap = qLocal[i] + int128 E = _exp(qLocal[i].mul(invB)); + int128 rhs = r0_j.add(ONE).sub(E); // r0 + 1 - E + + if (rhs <= int128(0)) { + // numeric issue: treat as zero contribution + unchecked { j++; } + continue; + } + + // amountInUsed = b * ln( r0 / rhs ) + int128 amountInUsed = b.mul(_ln(r0_j.div(rhs))); + + // Update q_local: pool receives amountInUsed on asset j, and loses qLocal[i] + qLocal[j] = qLocal[j].add(amountInUsed); + // subtract capped output from qLocal[i] (becomes zero) + totalOut = totalOut.add(qLocal[i]); + qLocal[i] = int128(0); + anyNonZero = true; + unchecked { j++; } + continue; + } + + // Normal path: use full aj as input and y as output + // Update q_local accordingly: pool receives aj on j, and loses y on i + qLocal[j] = qLocal[j].add(aj); + qLocal[i] = qLocal[i].sub(y); + totalOut = totalOut.add(y); + anyNonZero = true; + } + } + unchecked { j++; } + } + + // If no asset contributed (totalOut == 0) treat as no-trade and revert + if (!anyNonZero || totalOut <= int128(0)) { + revert("LMSR: zero output"); + } + + amountOut = totalOut; + return (amountOut, amountIn); + } + + + /// @notice Updates the LMSR state after performing an asset-to-asset swap + /// Updates the internal qInternal cache with the new balances + /// @param i Index of input asset + /// @param j Index of output asset + /// @param amountIn Amount of input asset used (in int128 format, 64.64 fixed-point) + /// @param amountOut Amount of output asset provided (in int128 format, 64.64 fixed-point) + function applySwap( + State storage s, + uint256 i, + uint256 j, + int128 amountIn, + int128 amountOut + ) internal { + require(i < s.nAssets && j < s.nAssets, "LMSR: idx"); + require(amountIn > int128(0), "LMSR: amountIn <= 0"); + require(amountOut > int128(0), "LMSR: amountOut <= 0"); + + console2.log("\n=== Applying Swap ==="); + console2.log("Input asset:", i); + console2.log("Output asset:", j); + console2.log("Amount in (64.64):"); + console2.logInt(amountIn); + console2.log("Amount out (64.64):"); + console2.logInt(amountOut); + + // Update internal balances + s.qInternal[i] = s.qInternal[i].add(amountIn); + s.qInternal[j] = s.qInternal[j].sub(amountOut); + + console2.log("=== Swap Applied (qInternal updated) ===\n"); + } + + + /// @notice Update pool state for proportional mint/redeem operations + /// This maintains price neutrality by keeping q/b ratio constant + /// Updates the internal qInternal cache with the new balances + /// @param newQInternal New asset quantities after mint/redeem (64.64 format) + function updateForProportionalChange(State storage s, int128[] memory newQInternal) internal { + require(newQInternal.length == s.nAssets, "LMSR: length mismatch"); + + console2.log("LMSR.updateForProportionalChange: start"); + + // Compute new total for validation + int128 newTotal = _computeSizeMetric(newQInternal); + console2.log("new total"); + console2.logInt(newTotal); + + require(newTotal > int128(0), "LMSR: new total zero"); + + // With kappa formulation, b automatically scales with pool size + int128 newB = s.kappa.mul(newTotal); + console2.log("new effective b"); + console2.logInt(newB); + + // Update the cached qInternal with new values + for (uint i = 0; i < s.nAssets; ) { + s.qInternal[i] = newQInternal[i]; + unchecked { i++; } + } + + console2.log("LMSR.updateForProportionalChange: end"); + } + + /// @notice Price-share of asset i: exp(z_i) / Z (64.64) + function priceShare(State storage s, uint256 i) internal view returns (int128) { + int128 b = _computeB(s); + uint len = s.qInternal.length; + require(len > 0, "LMSR: no assets"); + + // Precompute reciprocal of b and perform a single pass that tracks M, Z, and e_i + int128 invB = ABDKMath64x64.div(ONE, b); + + // Initialize from the first element + int128 M = s.qInternal[0].mul(invB); + int128 Z = ONE; // exp(0) + int128 e_i_acc; + bool setEi; + + if (i == 0) { + e_i_acc = ONE; // exp(0) + setEi = true; + } + + for (uint idx = 1; idx < len; ) { + int128 yi = s.qInternal[idx].mul(invB); + if (yi <= M) { + // Add contribution under current center + int128 term = _exp(yi.sub(M)); + Z = Z.add(term); + if (idx == i) { + e_i_acc = term; + setEi = true; + } + } else { + // Rescale Z and any tracked e_i when center increases + int128 scale = _exp(M.sub(yi)); // == exp(-(yi - M)) + Z = Z.mul(scale).add(ONE); + if (setEi) { + e_i_acc = e_i_acc.mul(scale); + } + M = yi; + if (idx == i) { + e_i_acc = ONE; // exp(0) at the new center + setEi = true; + } + } + unchecked { idx++; } + } + + // If i was not seen in the loop (i == 0 handled above), ensure e_i_acc was set + if (!setEi) { + // Only possible when len == 1 and i != 0, guarded by caller invariants typically + // Fallback: compute directly (kept for completeness) + int128 yi = s.qInternal[i].mul(invB); + e_i_acc = _exp(yi.sub(M)); + } + + return e_i_acc.div(Z); + } + + /* -------------------- + Slippage -> b computation & resize-triggered rescale + -------------------- */ + + /// @notice Compute and set kappa from slippage parameters and current asset quantities + /// This should be called during initialization or when recalibrating the pool parameters + function setKappaFromSlippage( + State storage s, + int128 tradeFrac, + int128 targetSlippage + ) internal { + require(s.nAssets > 0, "LMSR: no assets"); + require(s.qInternal.length == s.nAssets, "LMSR: length mismatch"); + + int128 total = _computeSizeMetric(s.qInternal); + + // Detect degenerate "all balances equal" case: if every qInternal equals the first, + // prefer the equal-inventories closed-form to avoid taking the heterogeneous path. + bool allEqual = true; + int128 first = s.qInternal[0]; + for (uint i = 1; i < s.qInternal.length; ) { + if (s.qInternal[i] != first) { + allEqual = false; + break; + } + unchecked { i++; } + } + + int128 targetB; + if (allEqual) { + // All assets have identical internal balances -> use equal-case core explicitly. + targetB = _computeBFromSlippageCore(total, s.nAssets, tradeFrac, targetSlippage, true); + } else { + // Compute target b using representative per-asset q for improved numerical stability + targetB = _computeBFromSlippage(total, s.nAssets, tradeFrac, targetSlippage); + } + + // Numeric trace for debugging / verification + console2.log("setKappaFromSlippage: trace start"); + console2.log("Q (total, Q64.64):"); + console2.logInt(total); + console2.log("tradeFrac (f, Q64.64):"); + console2.logInt(tradeFrac); + console2.log("targetSlippage (s, Q64.64):"); + console2.logInt(targetSlippage); + console2.log("nAssets:"); + console2.logUint(s.nAssets); + console2.log("total (S(q), Q64.64):"); + console2.logInt(total); + console2.log("targetB (computed, Q64.64):"); + console2.logInt(targetB); + console2.log("setKappaFromSlippage: trace end"); + + // Compute kappa = b_target / S(q) + s.kappa = targetB.div(total); + require(s.kappa > int128(0), "LMSR: kappa<=0"); + + console2.log("Set kappa from slippage params:"); + console2.log("total"); + console2.logInt(total); + console2.log("targetB"); + console2.logInt(targetB); + console2.log("kappa"); + console2.logInt(s.kappa); + } + + /// @notice Public wrapper for computing b from slippage parameters. + /// Picks the degenerate closed-form when the heterogeneous invariants are not satisfied, + /// otherwise uses the heterogeneous derivation implemented in computeBFromSlippageCore. + function _computeBFromSlippage( + int128 q, // total assets + uint256 nAssets, + int128 tradeFrac, + int128 targetSlippage + ) internal pure returns (int128) { + // Quick sanity checks that decide whether the heterogeneous formula is applicable. + // If not, fall back to the closed-form equal-asset formula for stability. + int128 one = _one(); + int128 onePlusS = one.add(targetSlippage); + + int128 n64 = ABDKMath64x64.fromUInt(nAssets); + int128 nMinus1_64 = ABDKMath64x64.fromUInt(nAssets - 1); + + // If 1 + s >= n then heterogeneous formula degenerates; use equal-asset closed-form. + if (onePlusS >= n64) { + return _computeBFromSlippageCore(q, nAssets, tradeFrac, targetSlippage, true); + } + + // denom = n - (1+s) + int128 denom = n64.sub(onePlusS); + int128 prod = onePlusS.mul(nMinus1_64); + + // If prod <= 0 or denom >= prod then heterogeneous formula is not in its valid range. + if (!(prod > int128(0) && denom < prod)) { + return _computeBFromSlippageCore(q, nAssets, tradeFrac, targetSlippage, true); + } + + // Otherwise use the heterogeneous derivation. + return _computeBFromSlippageCore(q, nAssets, tradeFrac, targetSlippage, false); + } + + /// @notice Core implementation that computes b from slippage parameters. + /// If assumeEqual == true, uses the closed-form algebra for equal inventories. + /// Otherwise uses the general derivation (original heterogeneous formula). + function _computeBFromSlippageCore( + int128 q, // total assets + uint256 nAssets, + int128 tradeFrac, + int128 targetSlippage, + bool assumeEqual + ) internal pure returns (int128) { + require(nAssets > 1, "LMSR: n>1 required"); + require(q > int128(0), "LMSR: q>0"); + // f must be in (0,1) + int128 f = tradeFrac; + require(f > int128(0), "LMSR: f=0"); + require(f < ONE, "LMSR: f>=1"); + + int128 one = _one(); + + // Top-level input debug + console2.log("computeBFromSlippageCore: inputs"); + console2.log("q (64.64)"); + console2.logInt(q); + console2.log("nAssets"); + console2.logUint(nAssets); + console2.log("tradeFrac f (64.64)"); + console2.logInt(f); + console2.log("targetSlippage S (64.64)"); + console2.logInt(targetSlippage); + console2.log("assumeEqual"); + console2.logUint(assumeEqual ? 1 : 0); + + if (assumeEqual) { + // Closed-form equal-asset simplification for an n-asset pool: + // Let s be the target relative increase in OTHER assets' price-share when + // removing fraction f of a single asset. For equal inventories we derive: + // E = exp(-y*f) = (1 - s*(n-1)) / (1 + s) + // where y = q / b. Therefore: + // y = -ln(E) / f + // b = q / y = q * f / (-ln(E)) + + int128 nMinus1 = ABDKMath64x64.fromUInt(nAssets - 1); + int128 numerator = one.sub(targetSlippage.mul(nMinus1)); // 1 - s*(n-1) + int128 denominator = one.add(targetSlippage); // 1 + s + + console2.log("equal-case intermediates:"); + console2.log("numerator = 1 - s*(n-1)"); + console2.logInt(numerator); + console2.log("denominator = 1 + s"); + console2.logInt(denominator); + + require(numerator > int128(0), "LMSR: s too large for n"); // ensures ratio>0 + + int128 ratio = numerator.div(denominator); // E candidate + console2.log("E candidate (ratio = numerator/denominator)"); + console2.logInt(ratio); + + // E must be strictly between 0 and 1 for a positive y + require(ratio > int128(0) && ratio < one, "LMSR: bad E ratio"); + + int128 lnE = _ln(ratio); // ln(E) < 0 + console2.log("ln(E)"); + console2.logInt(lnE); + + // y = -ln(E) / f + int128 y = lnE.neg().div(f); + console2.log("y = -ln(E)/f"); + console2.logInt(y); + require(y > int128(0), "LMSR: y<=0"); + + int128 b = q.div(y); + console2.log("b = q / y (computed)"); + console2.logInt(b); + require(b > int128(0), "LMSR: b<=0"); + + // Simulate the slippage using this b to verify + int128 expArg = y.mul(f).neg(); + int128 E_sim = _exp(expArg); + int128 n64 = ABDKMath64x64.fromUInt(nAssets); + int128 nMinus1_64 = ABDKMath64x64.fromUInt(nAssets - 1); + int128 simulatedSlippage = n64.div(nMinus1_64.add(E_sim)).sub(_one()); + console2.log("simulatedSlippage (using computed b)"); + console2.logInt(simulatedSlippage); + + return b; + } else { + // Heterogeneous / general case (original derivation): + // E = exp(-y * f) where y = q / b + // and E = (1+s) * (n-1) / (n - (1+s)) + // so y = -ln(E) / f and b = q / y. + int128 onePlusS = one.add(targetSlippage); + + console2.log("heterogeneous intermediates:"); + console2.log("onePlusS = 1 + s"); + console2.logInt(onePlusS); + + int128 n64 = ABDKMath64x64.fromUInt(nAssets); + int128 nMinus1_64 = ABDKMath64x64.fromUInt(nAssets - 1); + + // denom = n - (1+s) + int128 denom = n64.sub(onePlusS); + console2.log("denom = n - (1+s)"); + console2.logInt(denom); + + // Guard and clamp pathological cases similar to previous logic + int128 eps = ABDKMath64x64.divu(1, 1_000_000_000); // small epsilon ~1e-9 in Q64.64 + if (onePlusS >= n64) { + console2.log('clamping'); + onePlusS = n64.sub(eps); + denom = n64.sub(onePlusS); + } + require(denom > int128(0), "LMSR: bad slippage or n"); + + int128 prod = onePlusS.mul(nMinus1_64); + console2.log("prod = (1+s)*(n-1)"); + console2.logInt(prod); + + if (!(prod > int128(0) && denom < prod)) { + if (denom >= prod) { + onePlusS = onePlusS.sub(eps); + denom = n64.sub(onePlusS); + prod = onePlusS.mul(nMinus1_64); + } + require(prod > int128(0) && denom < prod, "LMSR: slippage out of range"); + } + + // Correct E candidate for the slippage relation: + // E = (1 - s*(n-1)) / (1 + s) + int128 E_candidate = (one.sub(targetSlippage.mul(nMinus1_64))).div(onePlusS); + console2.log("E candidate ((1 - s*(n-1)) / (1+s))"); + console2.logInt(E_candidate); + + // Compute ln(E) directly from the ratio E_candidate for improved numerical stability + int128 lnE = _ln(E_candidate); + console2.log("lnE = ln(E_candidate)"); + console2.logInt(lnE); + + // y = -ln(E) / f + int128 y = lnE.neg().div(f); + console2.log("y = -ln(E)/f"); + console2.logInt(y); + require(y > int128(0), "LMSR: y<=0"); + + // b = q / y + int128 b = q.div(y); + console2.log("b = q / y (computed)"); + console2.logInt(b); + require(b > int128(0), "LMSR: b<=0"); + + // Simulate slippage using this b to verify + int128 expArg = y.mul(f).neg(); + int128 E_sim = _exp(expArg); + int128 simulatedSlippage = n64.div(nMinus1_64.add(E_sim)).sub(_one()); + console2.log("simulatedSlippage (heterogeneous)"); + console2.logInt(simulatedSlippage); + + return b; + } + } + + /// @notice De-initialize the LMSR state when the entire pool is drained. + /// This resets the state so the pool can be re-initialized by init(...) on next mint. + function deinit(State storage s) internal { + console2.log("LMSR.deinit: resetting state"); + + // Reset core state + s.nAssets = 0; + s.kappa = int128(0); + + // Clear qInternal array + delete s.qInternal; + + // Note: init(...) will recompute kappa and nAssets on first mint. + } + + /// @notice Compute M (shift) and Z (sum of exponentials) dynamically + function _computeMAndZ(int128 b, int128[] memory qInternal) private pure returns (int128 M, int128 Z) { + require(qInternal.length > 0, "LMSR: no assets"); + + // Precompute reciprocal of b to replace divisions with multiplications in the loop + int128 invB = ABDKMath64x64.div(ONE, b); + + // Initialize with the first element + uint len = qInternal.length; + M = qInternal[0].mul(invB); + Z = ONE; // only the first term contributes exp(0) = 1 + + // One-pass accumulation with on-the-fly recentering + for (uint i = 1; i < len; ) { + int128 yi = qInternal[i].mul(invB); + if (yi <= M) { + // Add exp(yi - M) to Z + Z = Z.add(_exp(yi.sub(M))); + } else { + // When a larger yi is found, rescale Z to the new center M := yi + // New Z = Z * exp(M - yi) + 1 + Z = Z.mul(_exp(M.sub(yi))).add(ONE); + M = yi; + } + unchecked { i++; } + } + } + + /// @notice Compute all e[i] = exp(z[i]) values dynamically + function _computeE(int128 b, int128[] memory qInternal, int128 M) private pure returns (int128[] memory e) { + uint len = qInternal.length; + e = new int128[](len); + + // Precompute reciprocal of b to avoid repeated divisions + int128 invB = ABDKMath64x64.div(ONE, b); + + for (uint i = 0; i < len; ) { + int128 y_i = qInternal[i].mul(invB); + int128 z_i = y_i.sub(M); + e[i] = _exp(z_i); + unchecked { i++; } + } + } + + /// @notice Compute r0 = e_i / e_j directly as exp((q_i - q_j) / b) + /// This avoids computing two separate exponentials and a division + function _computeR0(int128 b, int128[] memory qInternal, uint256 i, uint256 j) private pure returns (int128) { + return _exp(qInternal[i].sub(qInternal[j]).div(b)); + } + + + /* -------------------- + Low-level helpers + -------------------- */ + + // Precomputed Q64.64 representation of 1.0 (1 << 64). + int128 private constant ONE = 0x10000000000000000; + // Precomputed Q64.64 representation of 32.0 for exp guard + int128 private constant EXP_LIMIT = 0x200000000000000000; + + function _exp(int128 x) private pure returns (int128) { return ABDKMath64x64.exp(x); } + function _ln(int128 x) private pure returns (int128) { return ABDKMath64x64.ln(x); } + function _one() private pure returns (int128) { return ONE; } + + /// @notice Compute size metric S(q) = sum of all asset quantities + function _computeSizeMetric(int128[] memory qInternal) private pure returns (int128) { + int128 total = int128(0); + for (uint i = 0; i < qInternal.length; ) { + total = total.add(qInternal[i]); + unchecked { i++; } + } + return total; + } + + /// @notice Compute b from kappa and current asset quantities + function _computeB(State storage s) private view returns (int128) { + int128 sizeMetric = _computeSizeMetric(s.qInternal); + require(sizeMetric > int128(0), "LMSR: size metric zero"); + return s.kappa.mul(sizeMetric); + } + +} diff --git a/src/PartyPool.sol b/src/PartyPool.sol new file mode 100644 index 0000000..3e02668 --- /dev/null +++ b/src/PartyPool.sol @@ -0,0 +1,818 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.30; + +import "@abdk/ABDKMath64x64.sol"; +import "@openzeppelin/contracts/token/ERC20/ERC20.sol"; +import "@openzeppelin/contracts/token/ERC20/IERC20.sol"; +import "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol"; +import "@openzeppelin/contracts/utils/ReentrancyGuard.sol"; +import "./LMSRStabilized.sol"; +import "./IPartyPool.sol"; +import "./IPartyFlashCallback.sol"; + +/// @title PartyPool - LMSR-backed multi-asset pool with LP ERC20 token +/// @notice Uses LMSRStabilized library; stores per-token uint bases to convert to/from 64.64 fixed point. +/// - Caches qInternal[] (int128 64.64) and cachedUintBalances[] to minimize balanceOf() calls. +/// - swap and swapToLimit mimic core lib; mint/burn call updateForProportionalChange() and manage LP tokens. +contract PartyPool is IPartyPool, ERC20, ReentrancyGuard { + using ABDKMath64x64 for int128; + using LMSRStabilized for LMSRStabilized.State; + using SafeERC20 for IERC20; + + + // + // Immutable pool configuration + // + + address[] public tokens; // effectively immutable since there is no interface to change the tokens + function numTokens() external view returns (uint256) { return tokens.length; } + function allTokens() external view returns (address[] memory) { return tokens; } + + // NOTE that the slippage target is only exactly achieved in completely balanced pools where all assets are + // priced the same. This target is actually a minimum slippage that the pool imposes on traders, and the actual + // slippage cost can be multiples bigger in practice due to pool inventory imbalances. + int128 public immutable tradeFrac; // slippage target trade size as a fraction of one asset's inventory + int128 public immutable targetSlippage; // target slippage applied to that trade size + + // fee in parts-per-million (ppm), taken from inputs before swaps + uint256 public immutable swapFeePpm; + + // flash loan fee in parts-per-million (ppm) + uint256 public immutable flashFeePpm; + + // + // Internal state + // + + LMSRStabilized.State internal lmsr; + + // Cached on-chain balances (uint) and internal 64.64 representation + // balance / base = internal + uint256[] internal cachedUintBalances; + uint256[] internal bases; // per-token uint base used to scale token amounts <-> internal + + mapping(address=>uint) public tokenAddressToIndexPlusOne; // Uses index+1 so a result of 0 indicates a failed lookup + + uint256 public constant LP_SCALE = 1e18; // Scale used to convert LMSR lastTotal (Q64.64) into LP token units (uint) + + /// @param name_ LP token name + /// @param symbol_ LP token symbol + /// @param _tokens token addresses (n) + /// @param _bases scaling bases for each token (n) - used when converting to/from internal 64.64 amounts + /// @param _tradeFrac trade fraction in 64.64 fixed-point (as used by LMSR) + /// @param _targetSlippage target slippage in 64.64 fixed-point (as used by LMSR) + /// @param _swapFeePpm fee in parts-per-million, taken from swap input amounts before LMSR calculations + /// @param _flashFeePpm fee in parts-per-million, taken for flash loans + constructor( + string memory name_, + string memory symbol_, + address[] memory _tokens, + uint256[] memory _bases, + int128 _tradeFrac, + int128 _targetSlippage, + uint256 _swapFeePpm, + uint256 _flashFeePpm + ) ERC20(name_, symbol_) { + require(_tokens.length > 1, "Pool: need >1 asset"); + require(_tokens.length == _bases.length, "Pool: lengths mismatch"); + tokens = _tokens; + bases = _bases; + tradeFrac = _tradeFrac; + targetSlippage = _targetSlippage; + require(_swapFeePpm < 1_000_000, "Pool: fee >= ppm"); + swapFeePpm = _swapFeePpm; + require(_flashFeePpm < 1_000_000, "Pool: flash fee >= ppm"); + flashFeePpm = _flashFeePpm; + + uint256 n = _tokens.length; + + // Initialize LMSR state nAssets; full init occurs on first mint when quantities are known. + lmsr.nAssets = n; + + // Initialize token address to index mapping + for (uint i = 0; i < n;) { + tokenAddressToIndexPlusOne[_tokens[i]] = i + 1; + unchecked {i++;} + } + + // Initialize caches to zero + cachedUintBalances = new uint256[](n); + } + + + /* ---------------------- + Initialization / Mint / Burn (LP token managed) + ---------------------- */ + + /// @notice Calculate the proportional deposit amounts required for a given LP token amount + /// @param lpTokenAmount The amount of LP tokens desired + /// @return depositAmounts Array of token amounts to deposit (rounded up) + function computeMintAmounts(uint256 lpTokenAmount) public view returns (uint256[] memory depositAmounts) { + uint256 n = tokens.length; + depositAmounts = new uint256[](n); + + // If this is the first mint or pool is empty, return zeros + // For first mint, tokens should already be transferred to the pool + if (totalSupply() == 0 || lmsr.nAssets == 0) { + return depositAmounts; // Return zeros, initial deposit handled differently + } + + // Calculate deposit based on current proportions + uint256 totalLpSupply = totalSupply(); + + // lpTokenAmount / totalLpSupply = depositAmount / currentBalance + // Therefore: depositAmount = (lpTokenAmount * currentBalance) / totalLpSupply + // We round up to protect the pool + for (uint i = 0; i < n; i++) { + uint256 currentBalance = cachedUintBalances[i]; + // Calculate with rounding up: (a * b + c - 1) / c + depositAmounts[i] = (lpTokenAmount * currentBalance + totalLpSupply - 1) / totalLpSupply; + } + + return depositAmounts; + } + + /// @notice Calculate the proportional withdrawal amounts for a given LP token amount + /// @param lpTokenAmount The amount of LP tokens to burn + /// @return withdrawAmounts Array of token amounts to withdraw (rounded down) + function computeBurnAmounts(uint256 lpTokenAmount) external view returns (uint256[] memory withdrawAmounts) { + return _computeBurnAmounts(lpTokenAmount); + } + + function _computeBurnAmounts(uint256 lpTokenAmount) internal view returns (uint256[] memory withdrawAmounts) { + uint256 n = tokens.length; + withdrawAmounts = new uint256[](n); + + // If supply is zero or pool uninitialized, return zeros + if (totalSupply() == 0 || lmsr.nAssets == 0) { + return withdrawAmounts; // Return zeros, nothing to withdraw + } + + // Calculate withdrawal amounts based on current proportions + uint256 totalLpSupply = totalSupply(); + + // withdrawAmount = floor(lpTokenAmount * currentBalance / totalLpSupply) + for (uint i = 0; i < n; i++) { + uint256 currentBalance = cachedUintBalances[i]; + withdrawAmounts[i] = (lpTokenAmount * currentBalance) / totalLpSupply; + } + + return withdrawAmounts; + } + + /// @notice Proportional mint (or initial supply if first call). + /// For initial supply: assumes tokens have already been transferred to the pool + /// For subsequent mints: payer must approve tokens beforehand, receiver gets the LP tokens + /// @param payer address that provides the input tokens (ignored for initial deposit) + /// @param receiver address that receives the LP tokens + /// @param lpTokenAmount desired amount of LP tokens to mint (ignored for initial deposit) + /// @param deadline timestamp after which the transaction will revert. Pass 0 to ignore. + function mint(address payer, address receiver, uint256 lpTokenAmount, uint256 deadline) external nonReentrant { + require(deadline == 0 || block.timestamp <= deadline, "mint: deadline exceeded"); + uint256 n = tokens.length; + // Check if this is initial deposit + bool isInitialDeposit = totalSupply() == 0 || lmsr.nAssets == 0; + + require(lpTokenAmount > 0 || isInitialDeposit, "mint: zero LP amount"); + + // Capture old pool size metric (scaled) by computing from current balances + uint256 oldScaled = 0; + if (!isInitialDeposit) { + int128 oldTotal = _computeSizeMetric(lmsr.qInternal); + oldScaled = ABDKMath64x64.mulu(oldTotal, LP_SCALE); + } + + // For non-initial deposits, transfer tokens from payer + uint256[] memory depositAmounts = new uint256[](n); + + if (!isInitialDeposit) { + // Calculate required deposit amounts for the desired LP tokens + depositAmounts = computeMintAmounts(lpTokenAmount); + + // Transfer in all token amounts + for (uint i = 0; i < n; ) { + if (depositAmounts[i] > 0) { + _safeTransferFrom(tokens[i], payer, address(this), depositAmounts[i]); + } + unchecked { i++; } + } + } + + // Update cached balances for all assets + int128[] memory newQInternal = new int128[](n); + for (uint i = 0; i < n; ) { + uint256 bal = IERC20(tokens[i]).balanceOf(address(this)); + cachedUintBalances[i] = bal; + newQInternal[i] = _uintToInternalFloor(bal, bases[i]); + + // For initial deposit, record the actual deposited amounts + if (isInitialDeposit) { + depositAmounts[i] = bal; + } + + unchecked { i++; } + } + + // If first time, call init, otherwise update proportional change. + if (isInitialDeposit) { + // Initialize the stabilized LMSR state + lmsr.init(newQInternal, tradeFrac, targetSlippage); + } else { + // Update for proportional change + lmsr.updateForProportionalChange(newQInternal); + } + + // Compute actual LP tokens to mint based on change in size metric (scaled) + // floor truncation rounds in favor of the pool + int128 newTotal = _computeSizeMetric(newQInternal); + uint256 newScaled = ABDKMath64x64.mulu(newTotal, LP_SCALE); + uint256 actualLpToMint; + + if (isInitialDeposit) { + // Initial provisioning: mint newScaled (as LP units) + actualLpToMint = newScaled; + } else { + require(oldScaled > 0, "mint: oldScaled zero"); + uint256 delta = (newScaled > oldScaled) ? (newScaled - oldScaled) : 0; + // Proportional issuance: totalSupply * delta / oldScaled + if (delta > 0) { + // floor truncation rounds in favor of the pool + actualLpToMint = (totalSupply() * delta) / oldScaled; + } else { + actualLpToMint = 0; + } + } + + // For subsequent mints, ensure the calculated LP amount is not too different from requested + if (!isInitialDeposit) { + // Allow for some rounding error but ensure we're not far off from requested amount + require(actualLpToMint > 0, "mint: zero LP minted"); + + // Allow actual amount to be at most 0.00001% less than requested + // This accounts for rounding in deposit calculations + uint256 minAcceptable = lpTokenAmount * 99_999 / 100_000; + require(actualLpToMint >= minAcceptable, "mint: insufficient LP minted"); + } + + console2.log('actualLpToMint', actualLpToMint); + require( actualLpToMint > 0, "mint: zero LP amount"); + _mint(receiver, actualLpToMint); + emit Mint(payer, receiver, depositAmounts, actualLpToMint); + } + + /// @notice Burn LP tokens and withdraw the proportional basket to receiver. + /// Payer must own the LP tokens; withdraw amounts are computed from current proportions. + /// @param payer address that provides the LP tokens to burn + /// @param receiver address that receives the withdrawn tokens + /// @param lpAmount amount of LP tokens to burn (proportional withdrawal) + /// @param deadline timestamp after which the transaction will revert. Pass 0 to ignore. + function burn(address payer, address receiver, uint256 lpAmount, uint256 deadline) external nonReentrant { + require(deadline == 0 || block.timestamp <= deadline, "burn: deadline exceeded"); + uint256 n = tokens.length; + require(lpAmount > 0, "burn: zero lp"); + + uint256 supply = totalSupply(); + require(supply > 0, "burn: empty supply"); + require(lmsr.nAssets > 0, "burn: uninit pool"); + require(balanceOf(payer) >= lpAmount, "burn: insufficient LP"); + + // Refresh cached balances to reflect current on-chain balances before computing withdrawal amounts + for (uint i = 0; i < n; ) { + uint256 bal = IERC20(tokens[i]).balanceOf(address(this)); + cachedUintBalances[i] = bal; + unchecked { i++; } + } + + // Compute proportional withdrawal amounts for the requested LP amount (rounded down) + uint256[] memory withdrawAmounts = _computeBurnAmounts(lpAmount); + + // Transfer underlying tokens out to receiver according to computed proportions + for (uint i = 0; i < n; ) { + if (withdrawAmounts[i] > 0) { + _safeTransfer(tokens[i], receiver, withdrawAmounts[i]); + } + unchecked { i++; } + } + + // Update cached balances and internal q for all assets + int128[] memory newQInternal = new int128[](n); + for (uint i = 0; i < n; ) { + uint256 bal = IERC20(tokens[i]).balanceOf(address(this)); + cachedUintBalances[i] = bal; + newQInternal[i] = _uintToInternalFloor(bal, bases[i]); + unchecked { i++; } + } + + // Apply proportional update or deinitialize if drained + bool allZero = true; + for (uint i = 0; i < n; ) { + if (newQInternal[i] != int128(0)) { + allZero = false; + break; + } + unchecked { i++; } + } + + if (allZero) { + lmsr.deinit(); + } else { + lmsr.updateForProportionalChange(newQInternal); + } + + // Burn exactly the requested LP amount from payer (authorization via allowance) + if (msg.sender != payer) { + uint256 allowed = allowance(payer, msg.sender); + require(allowed >= lpAmount, "burn: allowance insufficient"); + _approve(payer, msg.sender, allowed - lpAmount); + } + _burn(payer, lpAmount); + + emit Burn(payer, receiver, withdrawAmounts, lpAmount); + } + + /* ---------------------- + Swaps + ---------------------- */ + + /// @notice Swap input token i -> token j. Payer must approve token i. + /// @param payer address of the account that pays for the swap + /// @param receiver address that will receive the output tokens + /// @param i index of input asset + /// @param j index of output asset + /// @param maxAmountIn maximum amount of token i (uint256) to transfer in (inclusive of fees) + /// @param limitPrice maximum acceptable marginal price (64.64 fixed point). Pass 0 to ignore. + /// @param deadline timestamp after which the transaction will revert. Pass 0 to ignore. + /// @return amountIn actual input used (uint256), amountOut actual output sent (uint256) + function swap( + address payer, + address receiver, + uint256 i, + uint256 j, + uint256 maxAmountIn, + int128 limitPrice, + uint256 deadline + ) external nonReentrant returns (uint256 amountIn, uint256 amountOut) { + uint256 n = tokens.length; + require(i < n && j < n, "swap: idx"); + require(maxAmountIn > 0, "swap: input zero"); + require(deadline == 0 || block.timestamp <= deadline, "swap: deadline exceeded"); + + // Read previous balances for affected assets + uint256 prevBalI = IERC20(tokens[i]).balanceOf(address(this)); + uint256 prevBalJ = IERC20(tokens[j]).balanceOf(address(this)); + + // Calculate fee (ceiling) and net amount + (, uint256 netUintForSwap) = _computeFee(maxAmountIn); + + // Convert the net amount to internal (floor) + int128 deltaInternalI = _uintToInternalFloor(netUintForSwap, bases[i]); + require(deltaInternalI > int128(0), "swap: input too small after fee"); + + // Make sure LMSR state exists + require(lmsr.nAssets > 0, "swap: empty pool"); + + // Compute swap amounts in internal space using exact-input logic (with limitPrice) + (int128 amountInInternalUsed, int128 amountOutInternal) = lmsr.swapAmountsForExactInput( + i, + j, + deltaInternalI, + limitPrice + ); + + // Convert actual used input internal -> uint (ceiling to protect the pool) + uint256 amountInUint = _internalToUintCeil(amountInInternalUsed, bases[i]); + + // Total transfer amount includes fee calculated on the actual used input (ceiling) + uint256 totalTransferAmount = amountInUint; + if (swapFeePpm > 0) { + uint256 feeOnUsed = _ceilFee(amountInUint, swapFeePpm); + totalTransferAmount += feeOnUsed; + } + + // Ensure we do not attempt to transfer more than the caller specified as maximum + require(totalTransferAmount > 0, 'swap: input zero'); + require(totalTransferAmount <= maxAmountIn, "swap: transfer exceeds max"); + + // Transfer the exact amount from payer and require exact receipt (revert on fee-on-transfer) + _safeTransferFrom(tokens[i], payer, address(this), totalTransferAmount); + uint256 balIAfter = IERC20(tokens[i]).balanceOf(address(this)); + require(balIAfter == prevBalI + totalTransferAmount, "swap: non-standard tokenIn"); + + // Compute output uint amount (floor) + uint256 amountOutUint = _internalToUintFloor(amountOutInternal, bases[j]); + require(amountOutUint > 0, "swap: output zero"); + + // Transfer output to receiver and verify exact decrease + _safeTransfer(tokens[j], receiver, amountOutUint); + uint256 balJAfter = IERC20(tokens[j]).balanceOf(address(this)); + require(balJAfter == prevBalJ - amountOutUint, "swap: non-standard tokenOut"); + + // Update cached uint balances for i and j using actual balances + cachedUintBalances[i] = balIAfter; + cachedUintBalances[j] = balJAfter; + + // Apply swap to LMSR state with the internal amounts actually used + // (fee is already accounted for in the reduced input amount) + lmsr.applySwap(i, j, amountInInternalUsed, amountOutInternal); + + emit Swap(payer, receiver, tokens[i], tokens[j], totalTransferAmount, amountOutUint); + + return (totalTransferAmount, amountOutUint); + } + + /// @notice Swap up to the price limit; computes max input to reach limit then performs swap. + /// If the pool can't fill entirely because of balances, it caps appropriately and returns actuals. + /// Payer must approve token i for the exact computed input amount. + /// @param deadline timestamp after which the transaction will revert. Pass 0 to ignore. + function swapToLimit( + address payer, + address receiver, + uint256 i, + uint256 j, + int128 limitPrice, + uint256 deadline + ) external returns (uint256 amountInUsed, uint256 amountOut) { + uint256 n = tokens.length; + require(i < n && j < n, "swapToLimit: idx"); + require(limitPrice > int128(0), "swapToLimit: limit <= 0"); + require(deadline == 0 || block.timestamp <= deadline, "swapToLimit: deadline exceeded"); + + // Ensure LMSR state exists + require(lmsr.nAssets > 0, "swapToLimit: pool uninitialized"); + + // Read previous balances for affected assets + uint256 prevBalI = IERC20(tokens[i]).balanceOf(address(this)); + uint256 prevBalJ = IERC20(tokens[j]).balanceOf(address(this)); + + // Compute maxima in internal space using library + (int128 amountInInternalMax, int128 amountOutInternal) = lmsr.swapAmountsForPriceLimit(i, j, limitPrice); + + // Calculate how much input will be needed with fee included (ceiling to protect the pool) + uint256 amountInUsedUint = _internalToUintCeil(amountInInternalMax, bases[i]); + require(amountInUsedUint > 0, "swapToLimit: input zero"); + + // Total transfer amount is the input amount including what will be taken as fee (ceiling) + uint256 totalTransferAmount = amountInUsedUint; + + if (swapFeePpm > 0) { + uint256 feeOnUsed = _ceilFee(amountInUsedUint, swapFeePpm); + totalTransferAmount += feeOnUsed; + } + + // Transfer the exact amount needed from payer and require exact receipt (revert on fee-on-transfer) + _safeTransferFrom(tokens[i], payer, address(this), totalTransferAmount); + uint256 balIAfter = IERC20(tokens[i]).balanceOf(address(this)); + require(balIAfter == prevBalI + totalTransferAmount, "swapToLimit: non-standard tokenIn"); + + // Compute output amount (floor) + uint256 amountOutUint = _internalToUintFloor(amountOutInternal, bases[j]); + require(amountOutUint > 0, "swapToLimit: output zero"); + + // Transfer output to receiver and verify exact decrease + _safeTransfer(tokens[j], receiver, amountOutUint); + uint256 balJAfter = IERC20(tokens[j]).balanceOf(address(this)); + require(balJAfter == prevBalJ - amountOutUint, "swapToLimit: non-standard tokenOut"); + + // Update caches to actual balances + cachedUintBalances[i] = balIAfter; + cachedUintBalances[j] = balJAfter; + + // Apply swap to LMSR state with the internal amounts + // (fee is already part of the reduced effective input) + lmsr.applySwap(i, j, amountInInternalMax, amountOutInternal); + + emit Swap(payer, receiver, tokens[i], tokens[j], amountInUsedUint, amountOutUint); + + return (amountInUsedUint, amountOutUint); + } + + /// @notice Ceiling fee helper: computes ceil(x * feePpm / 1_000_000) + function _ceilFee(uint256 x, uint256 feePpm) internal pure returns (uint256) { + if (feePpm == 0) return 0; + // ceil division: (num + denom - 1) / denom + return (x * feePpm + 1_000_000 - 1) / 1_000_000; + } + + /// @notice Compute fee and net amounts for a gross input (fee rounded up to favor the pool). + /// @return feeUint fee taken (uint) and netUint remaining for protocol use (uint) + function _computeFee(uint256 gross) internal view returns (uint256 feeUint, uint256 netUint) { + if (swapFeePpm == 0) { + return (0, gross); + } + feeUint = _ceilFee(gross, swapFeePpm); + netUint = gross - feeUint; + } + + /// @notice Convenience: return gross = net + fee(net) using ceiling for fee. + function _addFee(uint256 netUint) internal view returns (uint256 gross) { + if (swapFeePpm == 0) return netUint; + uint256 fee = _ceilFee(netUint, swapFeePpm); + return netUint + fee; + } + + // --- New events for single-token mint/burn flows --- + // Note: events intentionally avoid exposing internal ΔS and avoid duplicating LP mint/burn data + // which is already present in the standard Mint/Burn events. + + /// @notice Single-token mint: deposit a single token, charge swap-LMSR cost, and mint LP. + /// @param payer who transfers the input token + /// @param receiver who receives the minted LP tokens + /// @param i index of the input token + /// @param maxAmountIn maximum uint token input (inclusive of fee) + /// @param deadline optional deadline + /// @return lpMinted actual LP minted (uint) + function swapMint( + address payer, + address receiver, + uint256 i, + uint256 maxAmountIn, + uint256 deadline + ) external nonReentrant returns (uint256 lpMinted) { + uint256 n = tokens.length; + require(i < n, "swapMint: idx"); + require(maxAmountIn > 0, "swapMint: input zero"); + require(deadline == 0 || block.timestamp <= deadline, "swapMint: deadline"); + + // Ensure pool initialized + require(lmsr.nAssets > 0, "swapMint: uninit pool"); + + // compute fee on gross maxAmountIn to get an initial net estimate (we'll recompute based on actual used) + (, uint256 netUintGuess) = _computeFee(maxAmountIn); + + // Convert the net guess to internal (floor) + int128 netInternalGuess = _uintToInternalFloor(netUintGuess, bases[i]); + require(netInternalGuess > int128(0), "swapMint: input too small after fee"); + + // Use LMSR view to determine actual internal consumed and size-increase (ΔS) for mint + (int128 amountInInternalUsed, int128 sizeIncreaseInternal) = lmsr.swapAmountsForMint(i, netInternalGuess); + + // amountInInternalUsed may be <= netInternalGuess. Convert to uint (ceil) to determine actual transfer + uint256 amountInUint = _internalToUintCeil(amountInInternalUsed, bases[i]); + require(amountInUint > 0, "swapMint: input zero after internal conversion"); + + // Compute fee on the actual used input and total transfer amount (ceiling) + uint256 feeUintActual = _ceilFee(amountInUint, swapFeePpm); + uint256 totalTransfer = amountInUint + feeUintActual; + require(totalTransfer > 0 && totalTransfer <= maxAmountIn, "swapMint: transfer exceeds max"); + + // Record pre-balance and transfer tokens from payer, require exact receipt (revert on fee-on-transfer) + uint256 prevBalI = IERC20(tokens[i]).balanceOf(address(this)); + _safeTransferFrom(tokens[i], payer, address(this), totalTransfer); + uint256 balIAfter = IERC20(tokens[i]).balanceOf(address(this)); + require(balIAfter == prevBalI + totalTransfer, "swapMint: non-standard tokenIn"); + + // Update cached uint balances for token i (only i changed externally) + cachedUintBalances[i] = balIAfter; + + // Compute old and new scaled size metrics to determine LP minted + int128 oldTotal = _computeSizeMetric(lmsr.qInternal); + require(oldTotal > int128(0), "swapMint: zero total"); + uint256 oldScaled = ABDKMath64x64.mulu(oldTotal, LP_SCALE); + + int128 newTotal = oldTotal.add(sizeIncreaseInternal); + uint256 newScaled = ABDKMath64x64.mulu(newTotal, LP_SCALE); + + uint256 actualLpToMint; + if (totalSupply() == 0) { + // If somehow supply zero (shouldn't happen as lmsr.nAssets>0), mint newScaled + actualLpToMint = newScaled; + } else { + require(oldScaled > 0, "swapMint: oldScaled zero"); + uint256 delta = (newScaled > oldScaled) ? (newScaled - oldScaled) : 0; + if (delta > 0) { + // floor truncation rounds in favor of pool + actualLpToMint = (totalSupply() * delta) / oldScaled; + } else { + actualLpToMint = 0; + } + } + + require(actualLpToMint > 0, "swapMint: zero LP minted"); + + // Update LMSR internal state: scale qInternal proportionally by newTotal/oldTotal + int128[] memory newQInternal = new int128[](n); + for (uint256 idx = 0; idx < n; idx++) { + // newQInternal[idx] = qInternal[idx] * (newTotal / oldTotal) + newQInternal[idx] = lmsr.qInternal[idx].mul(newTotal).div(oldTotal); + } + + // Update cached internal and kappa via updateForProportionalChange + lmsr.updateForProportionalChange(newQInternal); + + // Note: we updated cachedUintBalances[i] above via reading balance; other token uint balances did not + // change externally (they were not transferred in). We keep cachedUintBalances for others unchanged. + // Mint LP tokens to receiver + _mint(receiver, actualLpToMint); + + // Emit SwapMint event with gross transfer, net input and fee (planned exact-in) + emit SwapMint(payer, receiver, i, totalTransfer, amountInUint, feeUintActual); + + // Emit standard Mint event which records deposit amounts and LP minted + emit Mint(payer, receiver, new uint256[](n), actualLpToMint); + // Note: depositAmounts array omitted (empty) since swapMint uses single-token input + + return actualLpToMint; + } + + /// @notice Burn LP tokens then swap the redeemed proportional basket into a single asset `i` and send to receiver. + /// @param payer who burns LP tokens + /// @param receiver who receives the single asset + /// @param lpAmount amount of LP tokens to burn + /// @param i index of target asset to receive + /// @param deadline optional deadline + /// @return amountOutUint uint amount of asset i sent to receiver + function burnSwap( + address payer, + address receiver, + uint256 lpAmount, + uint256 i, + uint256 deadline + ) external nonReentrant returns (uint256 amountOutUint) { + uint256 n = tokens.length; + require(i < n, "burnSwap: idx"); + require(lpAmount > 0, "burnSwap: zero lp"); + require(deadline == 0 || block.timestamp <= deadline, "burnSwap: deadline"); + + uint256 supply = totalSupply(); + require(supply > 0, "burnSwap: empty supply"); + require(balanceOf(payer) >= lpAmount, "burnSwap: insufficient LP"); + + // alpha = lpAmount / supply as Q64.64 + int128 alpha = ABDKMath64x64.divu(lpAmount, supply); + + // Use LMSR view to compute single-asset payout and burned size-metric + (int128 payoutInternal, ) = lmsr.swapAmountsForBurn(i, alpha); + + // Convert payoutInternal -> uint (floor) to favor pool + amountOutUint = _internalToUintFloor(payoutInternal, bases[i]); + require(amountOutUint > 0, "burnSwap: output zero"); + + // Transfer the payout to receiver + _safeTransfer(tokens[i], receiver, amountOutUint); + + // Burn LP tokens from payer (authorization via allowance) + if (msg.sender != payer) { + uint256 allowed = allowance(payer, msg.sender); + require(allowed >= lpAmount, "burnSwap: allowance insufficient"); + _approve(payer, msg.sender, allowed - lpAmount); + } + _burn(payer, lpAmount); + + // Update cached balances by reading on-chain balances for all tokens + int128[] memory newQInternal = new int128[](n); + for (uint256 idx = 0; idx < n; idx++) { + uint256 bal = IERC20(tokens[idx]).balanceOf(address(this)); + cachedUintBalances[idx] = bal; + newQInternal[idx] = _uintToInternalFloor(bal, bases[idx]); + } + + // Emit BurnSwap with public-facing info only (do not expose ΔS or LP burned) + emit BurnSwap(payer, receiver, i, amountOutUint); + + // If entire pool drained, deinit; else update proportionally + bool allZero = true; + for (uint256 idx = 0; idx < n; idx++) { + if (newQInternal[idx] != int128(0)) { allZero = false; break; } + } + if (allZero) { + lmsr.deinit(); + } else { + lmsr.updateForProportionalChange(newQInternal); + } + + emit Burn(payer, receiver, new uint256[](n), lpAmount); + return amountOutUint; + } + + + function computeFlashRepaymentAmounts(uint256[] memory loanAmounts) external view + returns (uint256[] memory repaymentAmounts) { + repaymentAmounts = new uint256[](tokens.length); + for (uint256 i = 0; i < tokens.length; i++) { + uint256 amount = loanAmounts[i]; + if (amount > 0) { + repaymentAmounts[i] = amount + _ceilFee(amount, flashFeePpm); + } + } + } + + + /// @notice Receive token0 and/or token1 and pay it back, plus a fee, in the callback + /// @dev The caller of this method receives a callback in the form of IPartyFlashCallback#partyFlashCallback + /// @param recipient The address which will receive the token amounts + /// @param amounts The amount of each token to send + /// @param data Any data to be passed through to the callback + function flash( + address recipient, + uint256[] memory amounts, + bytes calldata data + ) external nonReentrant { + require(recipient != address(0), "flash: zero recipient"); + require(amounts.length == tokens.length, "flash: amounts length mismatch"); + + // Calculate repayment amounts for each token including fee + uint256[] memory repaymentAmounts = new uint256[](tokens.length); + + // Store initial balances to verify repayment later + uint256[] memory initialBalances = new uint256[](tokens.length); + + // Track if any token amount is non-zero + bool hasNonZeroAmount = false; + + // Process each token, skipping those with zero amounts + for (uint256 i = 0; i < tokens.length; i++) { + uint256 amount = amounts[i]; + + if (amount > 0) { + hasNonZeroAmount = true; + + // Calculate repayment amount with fee (ceiling) + repaymentAmounts[i] = amount + _ceilFee(amount, flashFeePpm); + + // Record initial balance + initialBalances[i] = IERC20(tokens[i]).balanceOf(address(this)); + + // Transfer token to recipient + _safeTransfer(tokens[i], recipient, amount); + } + } + + // Ensure at least one token is being borrowed + require(hasNonZeroAmount, "flash: no tokens requested"); + + // Call flash callback with expected repayment amounts + IPartyFlashCallback(msg.sender).partyFlashCallback(amounts, repaymentAmounts, data); + + // Verify repayment amounts for tokens that were borrowed + for (uint256 i = 0; i < tokens.length; i++) { + if (amounts[i] > 0) { + uint256 currentBalance = IERC20(tokens[i]).balanceOf(address(this)); + + // Verify repayment: current balance must be at least (initial balance + fee) + require( + currentBalance >= initialBalances[i] + _ceilFee(amounts[i], flashFeePpm), + "flash: repayment failed" + ); + + // Update cached balance + cachedUintBalances[i] = currentBalance; + } + } + } + + + /* ---------------------- + Conversion helpers + ---------------------- */ + + // Convert uint token amount -> internal 64.64 (floor). Uses ABDKMath64x64.divu which truncates. + function _uintToInternalFloor(uint256 amount, uint256 base) internal pure returns (int128) { + // internal = amount / base (as Q64.64) + return ABDKMath64x64.divu(amount, base); + } + + // Convert internal 64.64 -> uint token amount (floor). Uses ABDKMath64x64.mulu which floors the product. + function _internalToUintFloor(int128 internalAmount, uint256 base) internal pure returns (uint256) { + // uint = internal * base (floored) + return ABDKMath64x64.mulu(internalAmount, base); + } + + // Convert internal 64.64 -> uint token amount (ceiling). Rounds up to protect the pool. + function _internalToUintCeil(int128 internalAmount, uint256 base) internal pure returns (uint256) { + // Get the floor value first + uint256 floorValue = ABDKMath64x64.mulu(internalAmount, base); + + // Check if there was any fractional part by comparing to a reconstruction of the original + int128 reconstructed = ABDKMath64x64.divu(floorValue, base); + + // If reconstructed is less than original, there was a fractional part that was truncated + if (reconstructed < internalAmount) { + return floorValue + 1; + } + + return floorValue; + } + + /* ---------------------- + ERC20 helpers (minimal) + ---------------------- */ + + function _safeTransferFrom(address token, address from, address to, uint256 amt) internal { + IERC20(token).safeTransferFrom(from, to, amt); + } + + function _safeTransfer(address token, address to, uint256 amt) internal { + IERC20(token).safeTransfer(to, amt); + } + + /// @notice Helper to compute size metric (sum of all asset quantities) from internal balances + function _computeSizeMetric(int128[] memory qInternal_) private pure returns (int128) { + int128 total = int128(0); + for (uint i = 0; i < qInternal_.length; ) { + total = total.add(qInternal_[i]); + unchecked { i++; } + } + return total; + } + +} diff --git a/test/Counter.t.sol b/test/Counter.t.sol new file mode 100644 index 0000000..4831910 --- /dev/null +++ b/test/Counter.t.sol @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.13; + +import {Test} from "forge-std/Test.sol"; +import {Counter} from "../src/Counter.sol"; + +contract CounterTest is Test { + Counter public counter; + + function setUp() public { + counter = new Counter(); + counter.setNumber(0); + } + + function test_Increment() public { + counter.increment(); + assertEq(counter.number(), 1); + } + + function testFuzz_SetNumber(uint256 x) public { + counter.setNumber(x); + assertEq(counter.number(), x); + } +} diff --git a/test/LMSRStabilized.t.sol b/test/LMSRStabilized.t.sol new file mode 100644 index 0000000..dabac54 --- /dev/null +++ b/test/LMSRStabilized.t.sol @@ -0,0 +1,861 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.20; + +import "forge-std/Test.sol"; +import "forge-std/console.sol"; +import "@openzeppelin/contracts/interfaces/IERC20Metadata.sol"; +import "../src/LMSRStabilized.sol"; + + +/// @notice Forge tests for LMSRStabilized +contract LMSRStabilizedTest is Test { + using LMSRStabilized for LMSRStabilized.State; + using ABDKMath64x64 for int128; + + LMSRStabilized.State internal s; + + int128 stdTradeSize; + int128 stdSlippage; + + + function setUp() public { + // 0.10% slippage when taking 1.00% of the assets + stdTradeSize = ABDKMath64x64.divu(100,10_000); + stdSlippage = ABDKMath64x64.divu(10,10_000); + } + + function initBalanced() internal { + int128[] memory q = new int128[](3); + q[0] = ABDKMath64x64.fromUInt(1_000_000); + q[1] = ABDKMath64x64.fromUInt(1_000_000); + q[2] = ABDKMath64x64.fromUInt(1_000_000); + s.init(q, stdTradeSize, stdSlippage); + } + + function initAlmostBalanced() internal { + int128[] memory q = new int128[](3); + q[0] = ABDKMath64x64.fromUInt(999_999); + q[1] = ABDKMath64x64.fromUInt(1_000_000); + q[2] = ABDKMath64x64.fromUInt(1_000_001); + s.init(q, stdTradeSize, stdSlippage); + } + + function initImbalanced() internal { + int128[] memory q = new int128[](4); + q[0] = ABDKMath64x64.fromUInt(1); + q[1] = ABDKMath64x64.fromUInt(1e9); + q[2] = ABDKMath64x64.fromUInt(1); + q[3] = ABDKMath64x64.divu(1, 1e9); + s.init(q, stdTradeSize, stdSlippage); + } + + + function testInitBalanced() public { + // Test 1: Balanced Pool Initialization + initBalanced(); + + // Create mock qInternal for testing + int128[] memory mockQInternal = new int128[](3); + mockQInternal[0] = ABDKMath64x64.fromUInt(1_000_000); + mockQInternal[1] = ABDKMath64x64.fromUInt(1_000_000); + mockQInternal[2] = ABDKMath64x64.fromUInt(1_000_000); + + // Update the state's cached qInternal + _updateCachedQInternal(mockQInternal); + + // Verify slippage by performing asset swaps and checking price impact + int128 tradeAmount = mockQInternal[0].mul(stdTradeSize); + + // For a balanced pool, test asset 0 -> asset 1 swap + (int128 amountIn, int128 amountOut) = s.swapAmountsForExactInput(0, 1, tradeAmount, 0); + + // Verify amountIn and amountOut are reasonable + assertTrue(amountIn > 0, "amountIn should be positive"); + assertTrue(amountOut > 0, "amountOut should be positive"); + + // Calculate slippage = (initialPrice/finalPrice - 1) + // Compute e values dynamically for price ratio + int128 b = _computeB(mockQInternal); + int128[] memory eValues = _computeE(b, mockQInternal); + + // For balanced pool, initial price ratio is 1:1 + int128 initialRatio = eValues[0].div(eValues[1]); + + // Verify initial ratio for balanced pool is approximately 1:1 + assertTrue((initialRatio.sub(ABDKMath64x64.fromInt(1))).abs() < ABDKMath64x64.divu(1, 10000), + "Initial price ratio should be close to 1:1"); + + // After trade, the new e values would be different + int128 newE0 = eValues[0].mul(_exp(tradeAmount.div(b))); + int128 slippageRatio = newE0.div(eValues[0]).div(eValues[1].div(eValues[1])); + int128 slippage = slippageRatio.sub(ABDKMath64x64.fromInt(1)); + console2.log('slippage', slippage); + + // Slippage should be close to stdSlippage (within 1% relative error) + int128 relativeError = slippage.sub(stdSlippage).abs().div(stdSlippage); + assertLt(relativeError, ABDKMath64x64.divu(1, 100), "Balanced pool slippage error too high"); + } + + function testInitAlmostBalanced() public { + // Test 2: Almost Balanced Pool Initialization + initAlmostBalanced(); + + // Create mock qInternal for testing + int128[] memory mockQInternal = new int128[](3); + mockQInternal[0] = ABDKMath64x64.fromUInt(999_999); + mockQInternal[1] = ABDKMath64x64.fromUInt(1_000_000); + mockQInternal[2] = ABDKMath64x64.fromUInt(1_000_001); + + // Update the state's cached qInternal + _updateCachedQInternal(mockQInternal); + + // Verify slippage for almost balanced pool + int128 tradeAmount = mockQInternal[0].mul(stdTradeSize); + + (int128 amountIn, int128 amountOut) = s.swapAmountsForExactInput(0, 1, tradeAmount, 0); + + // Verify amountIn and amountOut are reasonable + assertTrue(amountIn > 0, "amountIn should be positive"); + assertTrue(amountOut > 0, "amountOut should be positive"); + + // Compute e values dynamically for price ratio + int128 b = _computeB(mockQInternal); + int128[] memory eValues = _computeE(b, mockQInternal); + int128 initialRatio = eValues[0].div(eValues[1]); + int128 relDiff = (initialRatio.sub(ABDKMath64x64.fromInt(1))).abs(); + // Verify the initial ratio is close to but not exactly 1:1 + assertTrue(relDiff < ABDKMath64x64.divu(1, 1000), + "Initial ratio should be close to 1:1 for almost balanced pool"); + assertTrue(relDiff > ABDKMath64x64.divu(1, 10000000), + "Initial ratio should not be exactly 1:1 for almost balanced pool"); + + int128 newE0 = eValues[0].mul(_exp(tradeAmount.div(b))); + int128 slippageRatio = newE0.div(eValues[0]).div(eValues[1].div(eValues[1])); + int128 slippage = slippageRatio.sub(ABDKMath64x64.fromInt(1)); + console2.log('slippage', slippage); + int128 relativeError = slippage.sub(stdSlippage).abs().div(stdSlippage); + assertLt(relativeError, ABDKMath64x64.divu(1, 100), "Almost balanced pool slippage error too high"); + } + + function testInitImbalanced() public { + // Test 3: Imbalanced Pool Initialization + initImbalanced(); + + // Create mock qInternal for testing + int128[] memory mockQInternal = new int128[](4); + mockQInternal[0] = ABDKMath64x64.fromUInt(1); + mockQInternal[1] = ABDKMath64x64.fromUInt(1e9); + mockQInternal[2] = ABDKMath64x64.fromUInt(1); + mockQInternal[3] = ABDKMath64x64.divu(1, 1e9); + + // Update the state's cached qInternal + _updateCachedQInternal(mockQInternal); + + // For imbalanced pool, we need to try an "average" swap + // We'll use asset 0 -> asset 2 as it's more balanced than asset 0 -> asset 1 + int128 tradeAmount = mockQInternal[0].mul(stdTradeSize); + + // Compute e values dynamically for price ratio + int128 b = _computeB(mockQInternal); + int128[] memory eValues = _computeE(b, mockQInternal); + + // Verify the ratios between small and large assets is different + int128 initialRatio = eValues[0].div(eValues[3]); // Assets 0 and 2 match, and assets 1 and 3 match. 0 and 3 differ. + int128 relDiff = (initialRatio.sub(ABDKMath64x64.fromInt(1))).abs(); + // Verify initial ratio shows significant imbalance + assertTrue(relDiff != 0, "Initial ratio should show imbalance"); + + (int128 amountIn, int128 amountOut) = s.swapAmountsForExactInput(0, 2, tradeAmount, 0); + + // Verify amountIn and amountOut are reasonable + assertTrue(amountIn > 0, "amountIn should be positive"); + assertTrue(amountOut > 0, "amountOut should be positive"); + + int128 newE0 = eValues[0].mul(_exp(tradeAmount.div(b))); + int128 slippageRatio = newE0.div(eValues[0]).div(eValues[2].div(eValues[2])); + int128 slippage = slippageRatio.sub(ABDKMath64x64.fromInt(1)); + console2.log('slippage', slippage); + + // Since the imbalance is extreme, with one coin worth lots more than the others, the actual slippage for + // this swap is actually off by about 100% + // When we configure kappa, it is a best case slippage (worst case AMM loss) that only occurs with balanced + // assets + int128 relativeError = slippage.sub(stdSlippage).abs().div(stdSlippage); + console2.log('relative error', relativeError); + assertLt(relativeError, ABDKMath64x64.divu(100, 100), "Imbalanced pool slippage error too high"); + } + + function testRecentering() public { + // Recentering functionality has been removed since we no longer cache intermediate values + // This test is now a no-op but kept for API compatibility + initAlmostBalanced(); + + // Verify basic state is still functional + assertTrue(s.nAssets > 0, "State should still be initialized"); + assertTrue(s.kappa > int128(0), "Kappa should still be positive"); + } + + function testRescalingAfterDeposit() public { + // Initialize pool with almost balanced assets + initAlmostBalanced(); + + // Create initial asset quantities + int128[] memory initialQ = new int128[](3); + initialQ[0] = ABDKMath64x64.fromUInt(999_999); + initialQ[1] = ABDKMath64x64.fromUInt(1_000_000); + initialQ[2] = ABDKMath64x64.fromUInt(1_000_001); + + // Update the state's cached qInternal + _updateCachedQInternal(initialQ); + + // Store initial parameters + int128 initialB = _computeB(initialQ); + int128 initialKappa = s.kappa; + + // Simulate a deposit by increasing all asset quantities by 50% + int128[] memory newQ = new int128[](s.nAssets); + for (uint i = 0; i < s.nAssets; i++) { + // Increase by 50% + newQ[i] = initialQ[i].mul(ABDKMath64x64.fromUInt(3).div(ABDKMath64x64.fromUInt(2))); // 1.5x + } + + // Apply the update for proportional change + s.updateForProportionalChange(newQ); + + // Verify that b has been rescaled proportionally + int128 newB = _computeB(s.qInternal); + int128 expectedRatio = ABDKMath64x64.fromUInt(3).div(ABDKMath64x64.fromUInt(2)); // 1.5x + int128 actualRatio = newB.div(initialB); + + int128 tolerance = ABDKMath64x64.divu(1, 1000); // 0.1% tolerance + assertTrue((actualRatio.sub(expectedRatio)).abs() < tolerance, "b did not scale proportionally after deposit"); + + // Verify kappa remained unchanged + assertTrue((s.kappa.sub(initialKappa)).abs() < tolerance, "kappa should not change after deposit"); + + // Verify slippage target is still met by performing a trade + int128 tradeAmount = s.qInternal[0].mul(stdTradeSize); + (int128 amountIn, int128 amountOut) = s.swapAmountsForExactInput(0, 1, tradeAmount, 0); + + // Verify computed swap amounts + assertTrue(amountIn > 0, "Swap amountIn should be positive"); + assertTrue(amountOut > 0, "Swap amountOut should be positive"); + // Verify amountOut is reasonable compared to amountIn (not a severe loss) + assertTrue(amountOut.div(amountIn) > ABDKMath64x64.divu(9, 10), "Swap should not incur severe loss"); + + int128[] memory eValues = _computeE(newB, s.qInternal); + int128 newE0 = eValues[0].mul(_exp(tradeAmount.div(newB))); + int128 slippageRatio = newE0.div(eValues[0]).div(eValues[1].div(eValues[1])); + int128 slippage = slippageRatio.sub(ABDKMath64x64.fromInt(1)); + console2.log('post-deposit slippage', slippage); + + int128 relativeError = slippage.sub(stdSlippage).abs().div(stdSlippage); + assertLt(relativeError, ABDKMath64x64.divu(1, 100), "Slippage target not met after deposit"); + } + + function testRescalingAfterWithdrawal() public { + // Initialize pool with almost balanced assets + initAlmostBalanced(); + + // Create initial asset quantities + int128[] memory initialQ = new int128[](3); + initialQ[0] = ABDKMath64x64.fromUInt(999_999); + initialQ[1] = ABDKMath64x64.fromUInt(1_000_000); + initialQ[2] = ABDKMath64x64.fromUInt(1_000_001); + + // Update the state's cached qInternal + _updateCachedQInternal(initialQ); + + // Store initial parameters + int128 initialB = _computeB(initialQ); + int128 initialKappa = s.kappa; + + // Simulate a withdrawal by decreasing all asset quantities by 30% + int128[] memory newQ = new int128[](s.nAssets); + for (uint i = 0; i < s.nAssets; i++) { + // Decrease by 30% + newQ[i] = initialQ[i].mul(ABDKMath64x64.fromUInt(7).div(ABDKMath64x64.fromUInt(10))); // 0.7x + } + + // Apply the update for proportional change + s.updateForProportionalChange(newQ); + + // Verify that b has been rescaled proportionally + int128 newB = _computeB(s.qInternal); + int128 expectedRatio = ABDKMath64x64.fromUInt(7).div(ABDKMath64x64.fromUInt(10)); // 0.7x + int128 actualRatio = newB.div(initialB); + + int128 tolerance = ABDKMath64x64.divu(1, 1000); // 0.1% tolerance + assertTrue((actualRatio.sub(expectedRatio)).abs() < tolerance, "b did not scale proportionally after withdrawal"); + + // Verify kappa remained unchanged + assertTrue((s.kappa.sub(initialKappa)).abs() < tolerance, "kappa should not change after withdrawal"); + + // Verify slippage target is still met by performing a trade + int128 tradeAmount = s.qInternal[0].mul(stdTradeSize); + (int128 amountIn, int128 amountOut) = s.swapAmountsForExactInput(0, 1, tradeAmount, 0); + + // Verify computed swap amounts + assertTrue(amountIn > 0, "Swap amountIn should be positive"); + assertTrue(amountOut > 0, "Swap amountOut should be positive"); + // Verify amountOut is reasonable compared to amountIn (not a severe loss) + assertTrue(amountOut.div(amountIn) > ABDKMath64x64.divu(9, 10), "Swap should not incur severe loss"); + + int128[] memory eValues = _computeE(newB, s.qInternal); + int128 newE0 = eValues[0].mul(_exp(tradeAmount.div(newB))); + int128 slippageRatio = newE0.div(eValues[0]).div(eValues[1].div(eValues[1])); + int128 slippage = slippageRatio.sub(ABDKMath64x64.fromInt(1)); + console2.log('post-withdrawal slippage', slippage); + + int128 relativeError = slippage.sub(stdSlippage).abs().div(stdSlippage); + assertLt(relativeError, ABDKMath64x64.divu(1, 100), "Slippage target not met after withdrawal"); + } + + // --- tests probing numerical stability and boundary conditions --- + + /// @notice Recentering functionality has been removed - this test is now a no-op + function testRecenterShiftTooLargeReverts() public { + initAlmostBalanced(); + // Recentering has been removed, so this test now just verifies basic functionality + assertTrue(s.nAssets > 0, "State should still be initialized"); + } + + /// @notice limitPrice <= current price should revert (no partial fill) + function testLimitPriceRevertWhenAtOrBelowCurrent() public { + initBalanced(); + + // Create mock qInternal for testing + int128[] memory mockQInternal = new int128[](3); + mockQInternal[0] = ABDKMath64x64.fromUInt(1_000_000); + mockQInternal[1] = ABDKMath64x64.fromUInt(1_000_000); + mockQInternal[2] = ABDKMath64x64.fromUInt(1_000_000); + + // Update the state's cached qInternal + _updateCachedQInternal(mockQInternal); + + // For balanced pool r0 = 1. Use limitPrice == 1 which should revert. + int128 tradeAmount = mockQInternal[0].mul(stdTradeSize); + + vm.expectRevert(bytes("LMSR: limitPrice <= current price")); + this.externalSwapAmountsForExactInput(0, 1, tradeAmount, ABDKMath64x64.fromInt(1)); + } + + /// @notice If e_j == 0 we should revert early to avoid div-by-zero + function testEJZeroReverts() public { + initBalanced(); + + // Create mock qInternal where asset 1 has zero quantity + int128[] memory mockQInternal = new int128[](3); + mockQInternal[0] = ABDKMath64x64.fromUInt(1_000_000); + mockQInternal[1] = int128(0); // Zero quantity for asset 1 + mockQInternal[2] = ABDKMath64x64.fromUInt(1_000_000); + + // Update the state's cached qInternal + _updateCachedQInternal(mockQInternal); + + int128 tradeAmount = mockQInternal[0].mul(stdTradeSize); + + vm.expectRevert(bytes("LMSR: e_j==0")); + this.externalSwapAmountsForExactInput(0, 1, tradeAmount, 0); + } + + /// @notice swapAmountsForPriceLimit returns zero if limit equals current price + function testSwapAmountsForPriceLimitZeroWhenLimitEqualsPrice() public { + initBalanced(); + + // Create mock qInternal for testing + int128[] memory mockQInternal = new int128[](3); + mockQInternal[0] = ABDKMath64x64.fromUInt(1_000_000); + mockQInternal[1] = ABDKMath64x64.fromUInt(1_000_000); + mockQInternal[2] = ABDKMath64x64.fromUInt(1_000_000); + + // Update the state's cached qInternal + _updateCachedQInternal(mockQInternal); + + // For balanced pool r0 = 1. swapAmountsForPriceLimit with limit==1 should be zero + vm.expectRevert('LMSR: limitPrice <= current price'); + this.externalSwapAmountsForPriceLimit(0, 1, ABDKMath64x64.fromInt(1)); + + // Try with a limit price slightly above 1, which should not revert + try this.externalSwapAmountsForPriceLimit(0, 1, ABDKMath64x64.fromInt(1).add(ABDKMath64x64.divu(1, 1000))) returns (int128 _amountIn, int128 _maxOut) { + // Verify that the returned values are reasonable + assertTrue(_amountIn > 0, "amountIn should be positive for valid limit price"); + assertTrue(_maxOut > 0, "maxOut should be positive for valid limit price"); + } catch { + fail("Should not revert with limit price > current price"); + } + } + + function externalSwapAmountsForPriceLimit(uint256 i, uint256 j, int128 limitPrice) external view + returns (int128, int128) { + return s.swapAmountsForPriceLimit(i, j, limitPrice); + } + + /// @notice Gas/throughput test: perform 100 alternating swaps between asset 0 and 1 + function testSwapGas() public { + // Initialize the almost-balanced pool + initAlmostBalanced(); + + // Create mock qInternal that we'll update through swaps + int128[] memory currentQ = new int128[](3); + currentQ[0] = ABDKMath64x64.fromUInt(999_999); + currentQ[1] = ABDKMath64x64.fromUInt(1_000_000); + currentQ[2] = ABDKMath64x64.fromUInt(1_000_001); + + // Update the state's cached qInternal + _updateCachedQInternal(currentQ); + + // Perform 100 swaps, alternating between asset 0 -> 1 and 1 -> 0 + for (uint256 iter = 0; iter < 100; iter++) { + uint256 from = (iter % 2 == 0) ? 0 : 1; + uint256 to = (from == 0) ? 1 : 0; + + // Use standard trade size applied to the 'from' asset's current quantity + int128 tradeAmount = s.qInternal[from].mul(stdTradeSize); + + // Compute swap amounts and apply to state + (int128 amountIn, int128 amountOut) = s.swapAmountsForExactInput(from, to, tradeAmount, 0); + + // applySwap now updates the internal qInternal directly + s.applySwap(from, to, amountIn, amountOut); + } + } + + /// @notice Extremely large a that makes a/b exceed expLimit should revert + function testAmountOutABOverflowReverts() public { + initBalanced(); + + // Create mock qInternal for testing + int128[] memory mockQInternal = new int128[](3); + mockQInternal[0] = ABDKMath64x64.fromUInt(1_000_000); + mockQInternal[1] = ABDKMath64x64.fromUInt(1_000_000); + mockQInternal[2] = ABDKMath64x64.fromUInt(1_000_000); + + // Update the state's cached qInternal + _updateCachedQInternal(mockQInternal); + + int128 b = _computeB(mockQInternal); + // Pick a such that a/b = 33 (expLimit is 32). a = b * 33 + int128 aOverB_target = ABDKMath64x64.fromInt(33); + int128 a = b.mul(aOverB_target); + + vm.expectRevert(bytes("LMSR: a/b too large (would overflow exp)")); + this.externalSwapAmountsForExactInput(0, 1, a, 0); + } + + // Helper function to compute b from qInternal (either from provided array or state) + function _computeB(int128[] memory qInternal) internal view returns (int128) { + int128 sizeMetric = _computeSizeMetric(qInternal); + require(sizeMetric > int128(0), "LMSR: size metric zero"); + return s.kappa.mul(sizeMetric); + } + + // Overload that uses state's cached qInternal + function _computeB() internal view returns (int128) { + int128 sizeMetric = _computeSizeMetric(s.qInternal); + require(sizeMetric > int128(0), "LMSR: size metric zero"); + return s.kappa.mul(sizeMetric); + } + + // Helper function to compute size metric (sum of all asset quantities) + function _computeSizeMetric(int128[] memory qInternal) internal pure returns (int128) { + int128 total = int128(0); + for (uint i = 0; i < qInternal.length; ) { + total = total.add(qInternal[i]); + unchecked { i++; } + } + return total; + } + + // Helper function to update the state's cached qInternal + function _updateCachedQInternal(int128[] memory mockQInternal) internal { + // First ensure qInternal array exists with the right size + if (s.qInternal.length != mockQInternal.length) { + s.qInternal = new int128[](mockQInternal.length); + } + + // Copy values from mockQInternal to state's qInternal + for (uint i = 0; i < mockQInternal.length; ) { + s.qInternal[i] = mockQInternal[i]; + unchecked { i++; } + } + } + + // Helper function to compute M and Z dynamically + function _computeMAndZ(int128 b, int128[] memory qInternal) internal pure returns (int128 M, int128 Z) { + require(qInternal.length > 0, "LMSR: no assets"); + + // Compute y_i = q_i / b for numerical stability + int128[] memory y = new int128[](qInternal.length); + for (uint i = 0; i < qInternal.length; ) { + y[i] = qInternal[i].div(b); + unchecked { i++; } + } + + // Find max y for centering (M = maxY) + M = y[0]; + for (uint i = 1; i < qInternal.length; ) { + if (y[i] > M) M = y[i]; + unchecked { i++; } + } + + // Compute Z = sum of exp(z_i) where z_i = y_i - M + Z = int128(0); + for (uint i = 0; i < qInternal.length; ) { + int128 z_i = y[i].sub(M); + int128 e_i = _exp(z_i); + Z = Z.add(e_i); + unchecked { i++; } + } + } + + // Helper function to compute all e[i] = exp(z[i]) values dynamically + function _computeE(int128 b, int128[] memory qInternal) internal pure returns (int128[] memory e) { + (int128 M, ) = _computeMAndZ(b, qInternal); + e = new int128[](qInternal.length); + + for (uint i = 0; i < qInternal.length; ) { + int128 y_i = qInternal[i].div(b); + int128 z_i = y_i.sub(M); + e[i] = _exp(z_i); + unchecked { i++; } + } + } + + // Helper function to calculate exp (copied from LMSRStabilized library) + function _exp(int128 x) internal pure returns (int128) { + return ABDKMath64x64.exp(x); + } + + // External helper function that wraps swapAmountsForExactInput to properly handle reverts in tests + function externalSwapAmountsForExactInput( + uint i, + uint j, + int128 a, + int128 limitPrice + ) external view returns (int128 amountIn, int128 amountOut) { + return s.swapAmountsForExactInput(i, j, a, limitPrice); + } + + // External helper function that wraps recenterIfNeeded to properly handle reverts in tests + function externalRecenterIfNeeded() external { + // Recentering has been removed - this is now a no-op + } + + // External helper function that wraps applySwap to properly handle reverts in tests + function externalApplySwap( + uint i, + uint j, + int128 amountIn, + int128 amountOut + ) external { + s.applySwap(i, j, amountIn, amountOut); + } + + // Small helper: convert a Q64.64 int128 into micro-units (value * 1e6) as an int256 for readable logging. + // Example: if x represents 0.001 (Q64.64), _toMicro(x) will return ~1000. + function _toMicro(int128 x) internal pure returns (int256) { + int256 ONE = int256(uint256(0x10000000000000000)); // 2^64 + return (int256(x) * 1_000_000) / ONE; + } + + /// @notice Test that applySwap correctly validates swap parameters and updates qInternal + function testApplySwap() public { + // Initialize with balanced assets + initBalanced(); + + // Create mock qInternal for testing + int128[] memory mockQInternal = new int128[](3); + mockQInternal[0] = ABDKMath64x64.fromUInt(1_000_000); + mockQInternal[1] = ABDKMath64x64.fromUInt(1_000_000); + mockQInternal[2] = ABDKMath64x64.fromUInt(1_000_000); + + // Update the state's cached qInternal + _updateCachedQInternal(mockQInternal); + + // Save original values for comparison + int128 originalQ0 = s.qInternal[0]; + int128 originalQ1 = s.qInternal[1]; + + // Calculate swap amounts from asset 0 to asset 1 + int128 tradeAmount = mockQInternal[0].mul(stdTradeSize); + + (int128 amountIn, int128 amountOut) = s.swapAmountsForExactInput(0, 1, tradeAmount, 0); + + // Verify basic swap calculation worked + assertTrue(amountIn > 0, "amountIn should be positive"); + assertTrue(amountOut > 0, "amountOut should be positive"); + + // Apply the swap - should not revert for valid inputs + s.applySwap(0, 1, amountIn, amountOut); + + // Verify qInternal is correctly updated + // Input asset should increase by amountIn + assertEq(s.qInternal[0], originalQ0.add(amountIn), "qInternal[0] should be updated"); + // Output asset should decrease by amountOut + assertEq(s.qInternal[1], originalQ1.sub(amountOut), "qInternal[1] should be updated"); + } + + /// @notice Test path independence by comparing direct vs indirect swaps + function testPathIndependence() public { + // Start with a balanced pool + initBalanced(); + + // Create initial quantities + int128[] memory initialQValues = new int128[](s.nAssets); + initialQValues[0] = ABDKMath64x64.fromUInt(1_000_000); + initialQValues[1] = ABDKMath64x64.fromUInt(1_000_000); + initialQValues[2] = ABDKMath64x64.fromUInt(1_000_000); + + // Update the state's cached qInternal + _updateCachedQInternal(initialQValues); + + // Test path independence by computing swap outcomes without state changes + int128 directSwapAmount = initialQValues[0].mul(stdTradeSize); + + // Store a backup of the original values to restore between swaps + int128[] memory backupQ = new int128[](s.nAssets); + for (uint i = 0; i < s.nAssets; i++) { + backupQ[i] = s.qInternal[i]; + } + + // Path 1: Direct swap from asset 0 to asset 2 + (int128 directAmountIn, int128 directAmountOut) = s.swapAmountsForExactInput(0, 2, directSwapAmount, 0); + + // Restore original state for second path + _updateCachedQInternal(backupQ); + + // Path 2: Swap from asset 0 to asset 1, then from asset 1 to asset 2 + (int128 indirectAmountIn1, int128 indirectAmountOut1) = s.swapAmountsForExactInput(0, 1, directSwapAmount, 0); + + // Update state for second leg of indirect path + s.qInternal[0] = s.qInternal[0].sub(indirectAmountIn1); + s.qInternal[1] = s.qInternal[1].add(indirectAmountOut1); + + // Second swap: asset 1 -> asset 2 + (int128 indirectAmountIn2, int128 indirectAmountOut2) = s.swapAmountsForExactInput(1, 2, indirectAmountOut1, 0); + + // The path independence property isn't perfect due to discrete swap mechanics, + // but the difference should be within reasonable bounds + console2.log("Direct swap output:"); + console2.logInt(directAmountOut); + console2.log("Indirect swap total output:"); + console2.logInt(indirectAmountOut2); + + // Basic verification that both paths produce positive outputs + assertTrue(directAmountOut > 0, "Direct swap should produce positive output"); + assertTrue(indirectAmountOut2 > 0, "Indirect swap should produce positive output"); + } + + /// @notice Test round-trip trades to verify near-zero slippage + function testRoundTripTradesAcrossAllPools() public { + // Test with balanced pool only since we removed state caching + initBalanced(); + + // Create mock qInternal + int128[] memory initialQ = new int128[](3); + initialQ[0] = ABDKMath64x64.fromUInt(1_000_000); + initialQ[1] = ABDKMath64x64.fromUInt(1_000_000); + initialQ[2] = ABDKMath64x64.fromUInt(1_000_000); + + // Update the state's cached qInternal + _updateCachedQInternal(initialQ); + + console2.log("Testing round-trip trades for balanced pool"); + + // Use standard trade size + int128 tradeAmount = s.qInternal[0].mul(stdTradeSize); + + // Step 1: Swap asset 0 -> asset 1 + (int128 amountIn1, int128 amountOut1) = s.swapAmountsForExactInput(0, 1, tradeAmount, 0); + + // Update quantities for step 2 + s.qInternal[0] = s.qInternal[0].sub(amountIn1); + s.qInternal[1] = s.qInternal[1].add(amountOut1); + + // Step 2: Swap back asset 1 -> asset 0 + (int128 amountIn2, int128 amountOut2) = s.swapAmountsForExactInput(1, 0, amountOut1, 0); + + // Calculate round-trip slippage: (initial amount - final amount) / initial amount + int128 roundTripSlippage = (amountIn1.sub(amountOut2)).div(amountIn1); + + console2.log("Round-trip slippage (micro-units):"); + console2.logInt(_toMicro(roundTripSlippage)); + + // Verify round-trip slippage is reasonable + int128 tolerance = ABDKMath64x64.divu(1, 100000); // 0.001% tolerance + assertLt(roundTripSlippage.abs(), tolerance, "Round-trip slippage should be near zero"); + } + + /// @notice Test that slippage is approximately equal in both directions for small swaps + function testBidirectionalSlippageSymmetry() public { + // Initialize with balanced assets for clearest slippage measurement + initBalanced(); + + // Create mock qInternal + int128[] memory initialQ = new int128[](3); + initialQ[0] = ABDKMath64x64.fromUInt(1_000_000); + initialQ[1] = ABDKMath64x64.fromUInt(1_000_000); + initialQ[2] = ABDKMath64x64.fromUInt(1_000_000); + + // Update the state's cached qInternal + _updateCachedQInternal(initialQ); + + // Use small trade size for clear slippage measurement + int128 tradeSize = ABDKMath64x64.divu(5, 10_000); // 0.05% of pool + int128 tradeAmount0 = s.qInternal[0].mul(tradeSize); + int128 tradeAmount1 = s.qInternal[1].mul(tradeSize); + + // Store original state to restore between tests + int128[] memory backupQ = new int128[](s.nAssets); + for (uint i = 0; i < s.nAssets; i++) { + backupQ[i] = s.qInternal[i]; + } + + // First direction: asset 0 -> asset 1 + (int128 amountIn0to1, int128 amountOut0to1) = s.swapAmountsForExactInput(0, 1, tradeAmount0, 0); + + // Restore original state + _updateCachedQInternal(backupQ); + + // Second direction: asset 1 -> asset 0 + (int128 amountIn1to0, int128 amountOut1to0) = s.swapAmountsForExactInput(1, 0, tradeAmount1, 0); + + console2.log("0->1 swap amountIn:"); + console2.logInt(amountIn0to1); + console2.log("0->1 swap amountOut:"); + console2.logInt(amountOut0to1); + console2.log("1->0 swap amountIn:"); + console2.logInt(amountIn1to0); + console2.log("1->0 swap amountOut:"); + console2.logInt(amountOut1to0); + + // For balanced pools, the swap ratios should be approximately symmetric + int128 ratio0to1 = amountOut0to1.div(amountIn0to1); + int128 ratio1to0 = amountOut1to0.div(amountIn1to0); + + // Calculate relative difference between the ratios + int128 ratioDifference = (ratio0to1.sub(ratio1to0)).abs(); + int128 relativeRatioDiff = ratioDifference.div(ratio0to1.add(ratio1to0).div(ABDKMath64x64.fromInt(2))); + + console2.log("Relative ratio difference (micro-units):"); + console2.logInt(_toMicro(relativeRatioDiff)); + + // Assert that the relative difference between ratios is small + int128 tolerance = ABDKMath64x64.divu(5, 100); // 5% tolerance + assertLt(relativeRatioDiff, tolerance, + "Swap ratios should be approximately equal in both directions"); + } + + /// @notice Test that basic swap functionality works across multiple operations + function testZConsistencyAfterMultipleSwaps() public { + // Initialize with balanced assets + initBalanced(); + + // Create mock qInternal that we'll update through swaps + int128[] memory initialQ = new int128[](3); + initialQ[0] = ABDKMath64x64.fromUInt(1_000_000); + initialQ[1] = ABDKMath64x64.fromUInt(1_000_000); + initialQ[2] = ABDKMath64x64.fromUInt(1_000_000); + + // Update the state's cached qInternal + _updateCachedQInternal(initialQ); + + // Perform multiple swaps in different directions + for (uint i = 0; i < 5; i++) { + // Swap from asset i%3 to asset (i+1)%3 + uint from = i % 3; + uint to = (i + 1) % 3; + + int128 tradeAmount = s.qInternal[from].mul(stdTradeSize); + + (int128 amountIn, int128 amountOut) = s.swapAmountsForExactInput(from, to, tradeAmount, 0); + + // Apply swap to update internal state + s.applySwap(from, to, amountIn, amountOut); + + // Basic validation that swap worked + assertTrue(amountIn > 0, "amountIn should be positive"); + assertTrue(amountOut > 0, "amountOut should be positive"); + } + } + + // --- New tests for single-token mint/burn helpers --- + + /// @notice Basic sanity check for swapAmountsForMint: small single-token input + function testSwapAmountsForMintBasic() public { + initBalanced(); + + // Use a small single-token input (stdTradeSize fraction of asset 0) + int128 a = s.qInternal[0].mul(stdTradeSize); + + (int128 consumed, int128 lpIncrease) = s.swapAmountsForMint(0, a); + + // consumed must be non-negative and <= provided a (partial-fill allowed) + assertTrue(consumed > 0, "consumed should be positive"); + assertTrue(consumed <= a, "consumed should not exceed provided input"); + + // lpIncrease should be positive + assertTrue(lpIncrease > 0, "lpIncrease should be positive"); + } + + /// @notice Large input for swapAmountsForMint should return a valid partial fill (consumed <= provided) + function testSwapAmountsForMintLargeInputPartial() public { + initAlmostBalanced(); + + // Provide a large input far above stdTradeSize to exercise cap logic + int128 a = s.qInternal[0].mul(ABDKMath64x64.fromUInt(1000)); // 1000x one-asset quantity + + (int128 consumed, int128 lpIncrease) = s.swapAmountsForMint(0, a); + + // Should not consume more than provided + assertTrue(consumed <= a, "consumed must be <= provided"); + + // If nothing could be consumed, the helper should revert earlier; otherwise positive + assertTrue(consumed > 0, "consumed should be positive for large input in normal pools"); + assertTrue(lpIncrease > 0, "lpIncrease should be positive for large input"); + } + + /// @notice Basic swapAmountsForBurn sanity: small alpha should return positive single-asset payout + function testSwapAmountsForBurnBasic() public { + initBalanced(); + + // Burn alpha fraction of pool + int128 alpha = ABDKMath64x64.divu(1, 100); // 1% + int128 S = _computeSizeMetric(s.qInternal); + + (int128 payout, int128 burned) = s.swapAmountsForBurn(0, alpha); + + // burned should equal alpha * S + assertEq(burned, alpha.mul(S), "burned size-metric mismatch"); + + // payout should be positive + assertTrue(payout > 0, "payout must be positive for balanced pool burn"); + } + + /// @notice If some assets have zero quantity, burn should skip them but still return payout when possible + function testSwapAmountsForBurnWithZeroAsset() public { + initBalanced(); + + // Make asset 1 empty; others non-zero + int128[] memory mockQInternal = new int128[](3); + mockQInternal[0] = ABDKMath64x64.fromUInt(1_000_000); + mockQInternal[1] = int128(0); // zero + mockQInternal[2] = ABDKMath64x64.fromUInt(1_000_000); + _updateCachedQInternal(mockQInternal); + + int128 alpha = ABDKMath64x64.divu(1, 100); // 1% + (int128 payout, int128 burned) = s.swapAmountsForBurn(0, alpha); + + // Should still burn the size metric + int128 S = _computeSizeMetric(mockQInternal); + assertEq(burned, alpha.mul(S), "burned size-metric mismatch with zero asset present"); + + // Payout should be at least the direct redeemed portion (alpha * q_i) + assertTrue(payout >= alpha.mul(mockQInternal[0]), "payout should be >= direct redeemed portion"); + + // Payout must be positive + assertTrue(payout > 0, "payout must be positive even when one asset is zero"); + } + +} diff --git a/test/MockERC20.sol b/test/MockERC20.sol new file mode 100644 index 0000000..2dc7b78 --- /dev/null +++ b/test/MockERC20.sol @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.30; + +import {ERC20} from "@openzeppelin/contracts/token/ERC20/ERC20.sol"; + +contract MockERC20 is ERC20 { + uint8 private immutable _decimals; + + constructor(string memory name, string memory symbol, uint8 decimals_) ERC20(name, symbol) {_decimals = decimals_;} + + function decimals() public view virtual override returns (uint8) {return _decimals;} + function mint(address account, uint256 amount) external {_mint(account, amount);} + function burn(address account, uint256 amount) external {_burn(account, amount);} +} diff --git a/test/PartyPool.t.sol b/test/PartyPool.t.sol new file mode 100644 index 0000000..8f7abd7 --- /dev/null +++ b/test/PartyPool.t.sol @@ -0,0 +1,1423 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.30; + +import "forge-std/Test.sol"; +import "@abdk/ABDKMath64x64.sol"; +import "@openzeppelin/contracts/token/ERC20/ERC20.sol"; +import "../src/PartyPool.sol"; + +// Import the flash callback interface +import "../src/IPartyFlashCallback.sol"; + +/// @notice Test contract that implements the flash callback for testing flash loans +contract FlashBorrower is IPartyFlashCallback { + enum Action { + NORMAL, // Normal repayment + REPAY_NONE, // Don't repay anything + REPAY_PARTIAL, // Repay less than required + REPAY_NO_FEE, // Repay only the principal without fee + REPAY_EXACT, // Repay exactly the required amount + REPAY_EXTRA // Repay more than required (donation) + } + + Action public action; + address public pool; + address public recipient; + address[] public tokens; + + constructor(address _pool, address[] memory _tokens) { + pool = _pool; + tokens = _tokens; + } + + function setAction(Action _action, address _recipient) external { + action = _action; + recipient = _recipient; + } + + function flash(uint256[] memory amounts) external { + PartyPool(pool).flash(recipient, amounts, ""); + } + + function partyFlashCallback( + uint256[] memory loanAmounts, + uint256[] memory repaymentAmounts, + bytes calldata /* data */ + ) external override { + require(msg.sender == pool, "Callback not called by pool"); + + if (action == Action.NORMAL || action == Action.REPAY_EXTRA) { + // Normal or extra repayment - transfer required amounts back to pool + for (uint256 i = 0; i < loanAmounts.length; i++) { + if (loanAmounts[i] > 0) { + uint256 repaymentAmount = repaymentAmounts[i]; + + // For REPAY_EXTRA, add 1 to each repayment + if (action == Action.REPAY_EXTRA) { + repaymentAmount += 1; + } + + // Transfer from recipient back to pool + TestERC20(tokens[i]).transferFrom( + recipient, + pool, + repaymentAmount + ); + } + } + } else if (action == Action.REPAY_PARTIAL) { + // Repay half of the required amounts + for (uint256 i = 0; i < loanAmounts.length; i++) { + if (loanAmounts[i] > 0) { + uint256 partialRepayment = repaymentAmounts[i] / 2; + TestERC20(tokens[i]).transferFrom( + recipient, + pool, + partialRepayment + ); + } + } + } else if (action == Action.REPAY_NO_FEE) { + // Repay only the principal without fee + for (uint256 i = 0; i < loanAmounts.length; i++) { + if (loanAmounts[i] > 0) { + TestERC20(tokens[i]).transferFrom( + recipient, + pool, + loanAmounts[i] + ); + } + } + } else if (action == Action.REPAY_EXACT) { + // Repay exactly what was required + for (uint256 i = 0; i < loanAmounts.length; i++) { + if (loanAmounts[i] > 0) { + TestERC20(tokens[i]).transferFrom( + recipient, + pool, + repaymentAmounts[i] + ); + } + } + } + // For REPAY_NONE, do nothing (don't repay) + } +} + +/// @notice Minimal ERC20 token for tests with an external mint function. +contract TestERC20 is ERC20 { + constructor(string memory name_, string memory symbol_, uint256 initialSupply) ERC20(name_, symbol_) { + if (initialSupply > 0) { + _mint(msg.sender, initialSupply); + } + } + + function mint(address to, uint256 amount) external { + _mint(to, amount); + } + + // Expose convenient approve helper for tests (not necessary but handy) + function approveMax(address spender) external { + _approve(msg.sender, spender, type(uint256).max); + } +} + +/// @notice Tests for PartyPool wrapper functionality: mint/burn/swap behavior, edge-cases and protections. +contract PartyPoolTest is Test { + using ABDKMath64x64 for int128; + + TestERC20 token0; + TestERC20 token1; + TestERC20 token2; + TestERC20 token3; + TestERC20 token4; + TestERC20 token5; + TestERC20 token6; + TestERC20 token7; + TestERC20 token8; + TestERC20 token9; + PartyPool pool; + PartyPool pool10; + + address alice; + address bob; + + // Common parameters + int128 tradeFrac; + int128 targetSlippage; + + uint256 constant INIT_BAL = 1_000_000; // initial token units for each token (internal==amount when base==1) + uint256 constant BASE = 1; // use base=1 so internal amounts correspond to raw integers (Q64.64 units) + + function setUp() public { + alice = address(0xA11ce); + bob = address(0xB0b); + + // Deploy three ERC20 test tokens and mint initial supplies to this test contract for initial deposit + token0 = new TestERC20("T0", "T0", 0); + token1 = new TestERC20("T1", "T1", 0); + token2 = new TestERC20("T2", "T2", 0); + token3 = new TestERC20("T3", "T3", 0); + token4 = new TestERC20("T4", "T4", 0); + token5 = new TestERC20("T5", "T5", 0); + token6 = new TestERC20("T6", "T6", 0); + token7 = new TestERC20("T7", "T7", 0); + token8 = new TestERC20("T8", "T8", 0); + token9 = new TestERC20("T9", "T9", 0); + + // Mint initial balances to the test contract to perform initial deposit + token0.mint(address(this), INIT_BAL); + token1.mint(address(this), INIT_BAL); + token2.mint(address(this), INIT_BAL); + token3.mint(address(this), INIT_BAL); + token4.mint(address(this), INIT_BAL); + token5.mint(address(this), INIT_BAL); + token6.mint(address(this), INIT_BAL); + token7.mint(address(this), INIT_BAL); + token8.mint(address(this), INIT_BAL); + token9.mint(address(this), INIT_BAL); + + // Configure LMSR parameters similar to other tests: trade size 1% of asset -> 0.01, slippage 0.001 + tradeFrac = ABDKMath64x64.divu(100, 10_000); // 0.01 + targetSlippage = ABDKMath64x64.divu(10, 10_000); // 0.001 + + // Build arrays for pool constructor + address[] memory tokens = new address[](3); + tokens[0] = address(token0); + tokens[1] = address(token1); + tokens[2] = address(token2); + + uint256[] memory bases = new uint256[](3); + bases[0] = BASE; + bases[1] = BASE; + bases[2] = BASE; + + // Deploy pool with a small fee to test fee-handling paths (use 1000 ppm = 0.1%) + uint256 feePpm = 1000; + + pool = new PartyPool("LP", "LP", tokens, bases, tradeFrac, targetSlippage, feePpm, feePpm); + + // Transfer initial deposit amounts into pool before initial mint (pool expects tokens already in contract) + // We deposit equal amounts INIT_BAL for each token + token0.transfer(address(pool), INIT_BAL); + token1.transfer(address(pool), INIT_BAL); + token2.transfer(address(pool), INIT_BAL); + + // Perform initial mint (initial deposit); receiver is this contract + pool.mint(address(0), address(this), 0, 0); + + // Set up pool10 with 10 tokens + address[] memory tokens10 = new address[](10); + tokens10[0] = address(token0); + tokens10[1] = address(token1); + tokens10[2] = address(token2); + tokens10[3] = address(token3); + tokens10[4] = address(token4); + tokens10[5] = address(token5); + tokens10[6] = address(token6); + tokens10[7] = address(token7); + tokens10[8] = address(token8); + tokens10[9] = address(token9); + + uint256[] memory bases10 = new uint256[](10); + for (uint i = 0; i < 10; i++) { + bases10[i] = BASE; + } + + pool10 = new PartyPool("LP10", "LP10", tokens10, bases10, tradeFrac, targetSlippage, feePpm, feePpm); + + // Mint additional tokens for pool10 initial deposit + token0.mint(address(this), INIT_BAL); + token1.mint(address(this), INIT_BAL); + token2.mint(address(this), INIT_BAL); + token3.mint(address(this), INIT_BAL); + token4.mint(address(this), INIT_BAL); + token5.mint(address(this), INIT_BAL); + token6.mint(address(this), INIT_BAL); + token7.mint(address(this), INIT_BAL); + token8.mint(address(this), INIT_BAL); + token9.mint(address(this), INIT_BAL); + + // Transfer initial deposit amounts into pool10 + token0.transfer(address(pool10), INIT_BAL); + token1.transfer(address(pool10), INIT_BAL); + token2.transfer(address(pool10), INIT_BAL); + token3.transfer(address(pool10), INIT_BAL); + token4.transfer(address(pool10), INIT_BAL); + token5.transfer(address(pool10), INIT_BAL); + token6.transfer(address(pool10), INIT_BAL); + token7.transfer(address(pool10), INIT_BAL); + token8.transfer(address(pool10), INIT_BAL); + token9.transfer(address(pool10), INIT_BAL); + + // Perform initial mint for pool10 + pool10.mint(address(0), address(this), 0, 0); + + // For later tests we will mint tokens to alice/bob as needed + token0.mint(alice, INIT_BAL); + token1.mint(alice, INIT_BAL); + token2.mint(alice, INIT_BAL); + token3.mint(alice, INIT_BAL); + token4.mint(alice, INIT_BAL); + token5.mint(alice, INIT_BAL); + token6.mint(alice, INIT_BAL); + token7.mint(alice, INIT_BAL); + token8.mint(alice, INIT_BAL); + token9.mint(alice, INIT_BAL); + + token0.mint(bob, INIT_BAL); + token1.mint(bob, INIT_BAL); + token2.mint(bob, INIT_BAL); + token3.mint(bob, INIT_BAL); + token4.mint(bob, INIT_BAL); + token5.mint(bob, INIT_BAL); + token6.mint(bob, INIT_BAL); + token7.mint(bob, INIT_BAL); + token8.mint(bob, INIT_BAL); + token9.mint(bob, INIT_BAL); + } + + /// @notice Basic sanity: initial mint should have produced LP tokens for this contract and the pool holds tokens. + function testInitialMintAndLP() public view { + uint256 totalLp = pool.totalSupply(); + assertTrue(totalLp > 0, "Initial LP supply should be > 0"); + + // Pool should hold the initial token balances + assertEq(token0.balanceOf(address(pool)), INIT_BAL); + assertEq(token1.balanceOf(address(pool)), INIT_BAL); + assertEq(token2.balanceOf(address(pool)), INIT_BAL); + } + + /// @notice If a caller requests to mint a very small LP amount that results in zero actual LP minted, + /// the call should revert with "mint: zero LP minted" to protect the pool. + function testProportionalMintZeroLpReverts() public { + // Attempt to request a tiny LP amount (1) and expect revert because calculated actualLpToMint will be zero + + // Approve pool to transfer tokens on alice's behalf + vm.startPrank(alice); + token0.approve(address(pool), type(uint256).max); + token1.approve(address(pool), type(uint256).max); + token2.approve(address(pool), type(uint256).max); + + vm.expectRevert(bytes("mint: zero LP amount")); + pool.mint(alice, alice, 0, 0); + vm.stopPrank(); + } + + /// @notice If a caller requests to mint a very small LP amount (1 wei) the pool should + /// honor the request (or revert only for 0 requests). We must ensure the pool-rounding + /// does not undercharge (no value extraction). This test verifies the request succeeds + /// and that computed deposits are at least the proportional floor (ceil >= floor). + function testProportionalMintOneWeiSucceedsAndProtectsPool() public { + // Request a tiny LP amount (1 wei). Approve pool to transfer tokens on alice's behalf. + vm.startPrank(alice); + token0.approve(address(pool), type(uint256).max); + token1.approve(address(pool), type(uint256).max); + token2.approve(address(pool), type(uint256).max); + + // Inspect the deposit amounts that the pool will require (these are rounded up) + uint256[] memory deposits = pool.computeMintAmounts(1); + + // Basic sanity: deposits array length must match token count and not all zero necessarily + assertEq(deposits.length, 3); + + // Compute the floor-proportional amounts for comparison: floor(lp * bal / totalLp) + uint256 totalLp = pool.totalSupply(); + for (uint i = 0; i < deposits.length; i++) { + uint256 bal = IERC20(pool.allTokens()[i]).balanceOf(address(pool)); + uint256 floorProportional = (1 * bal) / totalLp; // floor + // Ceil (deposit) must be >= floor (pool protected) + assertTrue(deposits[i] >= floorProportional, "deposit must not be less than floor proportion"); + } + + // Perform the mint — it should succeed for a 1 wei request (pool uses ceil to protect itself) + pool.mint(alice, alice, 1, 0); + + // After mint, alice should have received at least 1 wei of LP + assertTrue(pool.balanceOf(alice) >= 1, "Alice should receive at least 1 wei LP"); + + vm.stopPrank(); + } + + /// @notice Ensure very-small proportional mints do not enable value extraction: + /// i.e. the depositor should not pay less underlying value per LP than existing LP holders. + function testNoExtraValueExtractionForTinyMint() public { + // Prepare: approve and snapshot pool state + vm.startPrank(alice); + token0.approve(address(pool), type(uint256).max); + token1.approve(address(pool), type(uint256).max); + token2.approve(address(pool), type(uint256).max); + + // Snapshot pool totals (simple value metric = sum of token uint balances since base==1 in tests) + address[] memory toks = pool.allTokens(); + uint256 n = toks.length; + uint256 poolValueBefore = 0; + for (uint i = 0; i < n; i++) { + poolValueBefore += IERC20(toks[i]).balanceOf(address(pool)); + } + uint256 totalLpBefore = pool.totalSupply(); + + // Compute required deposits and perform mint for 1 wei + uint256[] memory deposits = pool.computeMintAmounts(1); + + // Sum deposits as deposited_value + uint256 depositedValue = 0; + for (uint i = 0; i < n; i++) { + depositedValue += deposits[i]; + } + + // Execute mint; it may revert if actualLpToMint == 0 but for 1 wei we expect it to succeed per design. + pool.mint(alice, alice, 1, 0); + + // Observe minted LP + uint256 totalLpAfter = pool.totalSupply(); + require(totalLpAfter >= totalLpBefore, "invariant: total LP cannot decrease"); + uint256 minted = totalLpAfter - totalLpBefore; + require(minted > 0, "sanity: minted should be > 0 for this test"); + + // Economic invariant check: + // depositedValue / minted >= poolValueBefore / totalLpBefore + // Rearranged (to avoid fractional math): depositedValue * totalLpBefore >= poolValueBefore * minted + // Use >= to allow the pool to charge equal-or-more value per LP (protects against extraction). + bool ok; + // Guard against zero-totalLP (shouldn't happen because pool initialised in setUp) + if (totalLpBefore == 0) { + ok = true; + } else { + ok = (depositedValue * totalLpBefore) >= (poolValueBefore * minted); + } + + assertTrue(ok, "Economic invariant violated: depositor paid less value per LP than existing holders"); + + vm.stopPrank(); + } + + /// @notice computeMintAmounts should round up deposit amounts to protect the pool. + function testComputeMintAmountsRoundingUp() public view { + uint256 totalLp = pool.totalSupply(); + assertTrue(totalLp > 0, "precondition: total supply > 0"); + + // Request half of LP supply + uint256 want = totalLp / 2; + uint256[] memory deposits = pool.computeMintAmounts(want); + + // We expect each deposit to be roughly half the pool balance, but due to rounding up it should satisfy: + // deposits[i] * 2 >= cached balance (i.e., rounding up) + for (uint i = 0; i < deposits.length; i++) { + uint256 poolBal = IERC20(pool.allTokens()[i]).balanceOf(address(pool)); + // deposit * 2 should be at least poolBal (protecting pool by rounding up) + assertTrue(deposits[i] * 2 >= poolBal || deposits[i] * 2 + 1 >= poolBal, "deposit rounding up expected"); + } + } + + /// @notice Burning all underlying assets should redeem all LP and leave totalSupply == 0. + function testBurnFullRedemption() public { + uint256 totalLp = pool.totalSupply(); + assertTrue(totalLp > 0, "precondition: LP > 0"); + + // Compute amounts required to redeem entire supply (should be current balances) + uint256[] memory withdrawAmounts = pool.computeBurnAmounts(totalLp); + + // Sanity: withdrawAmounts should equal pool balances (or very close due to rounding) + for (uint i = 0; i < withdrawAmounts.length; i++) { + uint256 poolBal = IERC20(pool.allTokens()[i]).balanceOf(address(pool)); + // withdrawAmounts should not exceed pool balance + assertTrue(withdrawAmounts[i] <= poolBal, "withdraw amount cannot exceed pool balance"); + } + + // Burn by sending LP tokens from this contract (which holds initial LP from setUp) + // Call burn(payer=this, receiver=bob, lpAmount=totalLp) + pool.burn(address(this), bob, totalLp, 0); + + // After burning entire pool, totalSupply should be zero or very small (we expect zero since we withdrew all) + assertEq(pool.totalSupply(), 0); + + // Bob should have received the withdrawn tokens + for (uint i = 0; i < withdrawAmounts.length; i++) { + assertTrue(IERC20(pool.allTokens()[i]).balanceOf(bob) >= withdrawAmounts[i], "Bob should receive withdrawn tokens"); + } + } + + /// @notice swap should transfer input+fee from payer, send output to receiver, and not exceed maxAmountIn. + function testSwapExactInputWithFee() public { + // Use alice as payer and bob as receiver + uint256 maxIn = 10_000; + + // Ensure alice has tokens and approves pool + vm.prank(alice); + token0.approve(address(pool), type(uint256).max); + + uint256 balAliceBefore = token0.balanceOf(alice); + uint256 balPoolBefore = token0.balanceOf(address(pool)); + uint256 balReceiverBefore = token1.balanceOf(bob); + + // Execute swap: token0 -> token1 + vm.prank(alice); + (uint256 amountInUsed, uint256 amountOut) = pool.swap(alice, bob, 0, 1, maxIn, 0, 0); + + // Amounts should be positive and not exceed provided max + assertTrue(amountInUsed > 0, "expected some input used"); + assertTrue(amountOut > 0, "expected some output returned"); + assertTrue(amountInUsed <= maxIn, "used input must not exceed max"); + + // Alice's balance decreased by exactly amountInUsed + assertEq(token0.balanceOf(alice), balAliceBefore - amountInUsed); + + // Receiver (bob) gained amountOut of token1 + assertEq(token1.balanceOf(bob), balReceiverBefore + amountOut); + + // Pool's token0 balance increased by amountInUsed + assertEq(token0.balanceOf(address(pool)), balPoolBefore + amountInUsed); + } + + /// @notice swap with limitPrice <= current price should bubble up the LMSR revert. + function testSwapLimitPriceRevert() public { + // Current marginal price for balanced pool is ~1: set limitPrice == 1 to trigger LMSR revert + int128 limitPrice = ABDKMath64x64.fromInt(1); + + vm.prank(alice); + token0.approve(address(pool), type(uint256).max); + + vm.prank(alice); + vm.expectRevert(bytes("LMSR: limitPrice <= current price")); + pool.swap(alice, alice, 0, 1, 1000, limitPrice, 0); + } + + /// @notice swapToLimit should compute input needed to reach a slightly higher price and execute. + function testSwapToLimit() public { + // Choose a limit price slightly above current (~1) + int128 limitPrice = ABDKMath64x64.fromInt(1).add(ABDKMath64x64.divu(1, 1000)); + + vm.prank(alice); + token0.approve(address(pool), type(uint256).max); + + vm.prank(alice); + (uint256 amountInUsed, uint256 amountOut) = pool.swapToLimit(alice, bob, 0, 1, limitPrice, 0); + + assertTrue(amountInUsed > 0, "expected some input used for swapToLimit"); + assertTrue(amountOut > 0, "expected some output for swapToLimit"); + + // Verify bob got the output + assertEq(token1.balanceOf(bob) >= amountOut, true); + } + + /// @notice Gas measurement: perform 100 swaps back-and-forth between token0 and token1. + function testSwapGas3() public { + // Ensure alice approves pool for both tokens + vm.prank(alice); + token0.approve(address(pool), type(uint256).max); + vm.prank(alice); + token1.approve(address(pool), type(uint256).max); + + uint256 maxIn = 1_000; + + // Perform 100 swaps alternating directions to avoid large imbalance + for (uint256 i = 0; i < 100; i++) { + vm.prank(alice); + if (i % 2 == 0) { + // swap token0 -> token1 + pool.swap(alice, alice, 0, 1, maxIn, 0, 0); + } else { + // swap token1 -> token0 + pool.swap(alice, alice, 1, 0, maxIn, 0, 0); + } + } + } + + /// @notice Gas measurement: perform 100 swaps back-and-forth between token0 and token1 in the 10-token pool. + function testSwapGas10() public { + // Ensure alice approves pool10 for both tokens + vm.prank(alice); + token0.approve(address(pool10), type(uint256).max); + vm.prank(alice); + token1.approve(address(pool10), type(uint256).max); + + uint256 maxIn = 1_000; + + // Perform 100 swaps alternating directions to avoid large imbalance + for (uint256 i = 0; i < 100; i++) { + vm.prank(alice); + if (i % 2 == 0) { + // swap token0 -> token1 + pool10.swap(alice, alice, 0, 1, maxIn, 0, 0); + } else { + // swap token1 -> token0 + pool10.swap(alice, alice, 1, 0, maxIn, 0, 0); + } + } + } + + /// @notice Gas-style test: alternate swapMint then burnSwap on the 3-token pool to keep pool size roughly stable. + function testSwapMintBurnSwapGas3() public { + uint256 iterations = 100; + uint256 input = 1_000; + + // Top up alice so repeated operations won't fail + token0.mint(alice, iterations * input * 2); + + vm.startPrank(alice); + token0.approve(address(pool), type(uint256).max); + + for (uint256 k = 0; k < iterations; k++) { + // Mint LP by providing single-token input; receive LP minted + uint256 minted = pool.swapMint(alice, alice, 0, input, 0); + // If nothing minted (numerical edge), skip burn step + if (minted == 0) continue; + // Immediately burn the minted LP back to tokens, targeting the same token index + pool.burnSwap(alice, alice, minted, 0, 0); + } + + vm.stopPrank(); + } + + /// @notice Gas-style test: alternate swapMint then burnSwap on the 10-token pool to keep pool size roughly stable. + function testSwapMintBurnSwapGas10() public { + uint256 iterations = 100; + uint256 input = 1_000; + + // Top up alice so repeated operations won't fail + token0.mint(alice, iterations * input * 2); + + vm.startPrank(alice); + token0.approve(address(pool10), type(uint256).max); + + for (uint256 k = 0; k < iterations; k++) { + uint256 minted = pool10.swapMint(alice, alice, 0, input, 0); + if (minted == 0) continue; + pool10.burnSwap(alice, alice, minted, 0, 0); + } + + vm.stopPrank(); + } + + /// @notice Combined gas test (mint then burn) on 3-token pool using mint() and burn(). + /// Alternates minting a tiny LP amount and immediately burning the actual minted LP back to avoid net pool depletion. + function testMintBurnGas3() public { + uint256 iterations = 50; + uint256 input = 1_000; + + // Ensure alice has enough tokens for all mints + token0.mint(alice, iterations * input * 2); + token1.mint(alice, iterations * input * 2); + token2.mint(alice, iterations * input * 2); + + vm.startPrank(alice); + // Approve pool to transfer tokens for proportional mint + token0.approve(address(pool), type(uint256).max); + token1.approve(address(pool), type(uint256).max); + token2.approve(address(pool), type(uint256).max); + + for (uint256 k = 0; k < iterations; k++) { + // Request a tiny LP mint (1 wei) - pool will compute deposits and transfer from alice + uint256 lpRequest = 1; + + // Snapshot alice LP before to compute actual minted + uint256 lpBefore = pool.balanceOf(alice); + + // Perform mint; this will transfer underlying from alice into pool + pool.mint(alice, alice, lpRequest, 0); + + uint256 lpAfter = pool.balanceOf(alice); + uint256 actualMinted = lpAfter - lpBefore; + + // If nothing minted due to rounding edge, skip burn + if (actualMinted == 0) { + continue; + } + + // Burn via plain burn() which will transfer underlying back to alice and burn LP + pool.burn(alice, alice, actualMinted, 0); + } + + vm.stopPrank(); + } + + /// @notice Verify computeMintAmounts matches the actual token transfers performed by mint() + function testComputeMintAmountsMatchesMint_3TokenPool() public { + // Use a range of LP requests (tiny to large fraction) + uint256 totalLp = pool.totalSupply(); + uint256[] memory requests = new uint256[](4); + requests[0] = 1; + requests[1] = totalLp / 100; // 1% + requests[2] = totalLp / 10; // 10% + requests[3] = totalLp / 2; // 50% + for (uint k = 0; k < requests.length; k++) { + uint256 req = requests[k]; + if (req == 0) req = 1; + + // Compute expected deposit amounts via view + uint256[] memory expected = pool.computeMintAmounts(req); + + // Ensure alice has tokens and approve pool + vm.startPrank(alice); + token0.approve(address(pool), type(uint256).max); + token1.approve(address(pool), type(uint256).max); + token2.approve(address(pool), type(uint256).max); + + // Snapshot alice balances before mint + uint256 a0Before = token0.balanceOf(alice); + uint256 a1Before = token1.balanceOf(alice); + uint256 a2Before = token2.balanceOf(alice); + + // Perform mint (may revert for zero-request; ensure req>0 above) + // Guard: if computeMintAmounts returned all zeros, skip (nothing to transfer) + bool allZero = (expected[0] == 0 && expected[1] == 0 && expected[2] == 0); + if (!allZero) { + uint256 lpBefore = pool.balanceOf(alice); + pool.mint(alice, alice, req, 0); + uint256 lpAfter = pool.balanceOf(alice); + // Confirm some LP minted (or at least not negative) + assertTrue(lpAfter >= lpBefore, "LP minted should not decrease"); + + // Check actual spent equals expected deposit amounts + assertEq(a0Before - token0.balanceOf(alice), expected[0], "token0 spent mismatch"); + assertEq(a1Before - token1.balanceOf(alice), expected[1], "token1 spent mismatch"); + assertEq(a2Before - token2.balanceOf(alice), expected[2], "token2 spent mismatch"); + } + + vm.stopPrank(); + } + } + + /// @notice Verify computeMintAmounts matches the actual token transfers performed by mint() for 10-token pool + function testComputeMintAmountsMatchesMint_10TokenPool() public { + uint256 totalLp = pool10.totalSupply(); + uint256[] memory requests = new uint256[](4); + requests[0] = 1; + requests[1] = totalLp / 100; + requests[2] = totalLp / 10; + requests[3] = totalLp / 2; + for (uint k = 0; k < requests.length; k++) { + uint256 req = requests[k]; + if (req == 0) req = 1; + + uint256[] memory expected = pool10.computeMintAmounts(req); + + // Approve all tokens from alice + vm.startPrank(alice); + token0.approve(address(pool10), type(uint256).max); + token1.approve(address(pool10), type(uint256).max); + token2.approve(address(pool10), type(uint256).max); + token3.approve(address(pool10), type(uint256).max); + token4.approve(address(pool10), type(uint256).max); + token5.approve(address(pool10), type(uint256).max); + token6.approve(address(pool10), type(uint256).max); + token7.approve(address(pool10), type(uint256).max); + token8.approve(address(pool10), type(uint256).max); + token9.approve(address(pool10), type(uint256).max); + + // Snapshot alice balances before + uint256[] memory beforeBal = new uint256[](10); + beforeBal[0] = token0.balanceOf(alice); + beforeBal[1] = token1.balanceOf(alice); + beforeBal[2] = token2.balanceOf(alice); + beforeBal[3] = token3.balanceOf(alice); + beforeBal[4] = token4.balanceOf(alice); + beforeBal[5] = token5.balanceOf(alice); + beforeBal[6] = token6.balanceOf(alice); + beforeBal[7] = token7.balanceOf(alice); + beforeBal[8] = token8.balanceOf(alice); + beforeBal[9] = token9.balanceOf(alice); + + bool allZero = true; + for (uint i = 0; i < 10; i++) { if (expected[i] != 0) { allZero = false; break; } } + + if (!allZero) { + pool10.mint(alice, alice, req, 0); + + // Verify each token spent equals expected + assertEq(beforeBal[0] - token0.balanceOf(alice), expected[0], "t0 spent mismatch"); + assertEq(beforeBal[1] - token1.balanceOf(alice), expected[1], "t1 spent mismatch"); + assertEq(beforeBal[2] - token2.balanceOf(alice), expected[2], "t2 spent mismatch"); + assertEq(beforeBal[3] - token3.balanceOf(alice), expected[3], "t3 spent mismatch"); + assertEq(beforeBal[4] - token4.balanceOf(alice), expected[4], "t4 spent mismatch"); + assertEq(beforeBal[5] - token5.balanceOf(alice), expected[5], "t5 spent mismatch"); + assertEq(beforeBal[6] - token6.balanceOf(alice), expected[6], "t6 spent mismatch"); + assertEq(beforeBal[7] - token7.balanceOf(alice), expected[7], "t7 spent mismatch"); + assertEq(beforeBal[8] - token8.balanceOf(alice), expected[8], "t8 spent mismatch"); + assertEq(beforeBal[9] - token9.balanceOf(alice), expected[9], "t9 spent mismatch"); + } + + vm.stopPrank(); + } + } + + /// @notice Verify computeBurnAmounts matches actual transfers performed by burn() for 3-token pool + function testComputeBurnAmountsMatchesBurn_3TokenPool() public { + // Use address(this) as payer (holds initial LP from setUp) + uint256 totalLp = pool.totalSupply(); + uint256[] memory burns = new uint256[](4); + burns[0] = 1; + burns[1] = totalLp / 100; + burns[2] = totalLp / 10; + burns[3] = totalLp / 2; + for (uint k = 0; k < burns.length; k++) { + uint256 req = burns[k]; + if (req == 0) req = 1; + + // Ensure this contract has enough LP to cover the requested burn; top up from alice if needed + uint256 myLp = pool.balanceOf(address(this)); + if (myLp < req) { + uint256 topUp = req - myLp; + // Have alice supply tokens to mint LP into this contract + vm.startPrank(alice); + token0.approve(address(pool), type(uint256).max); + token1.approve(address(pool), type(uint256).max); + token2.approve(address(pool), type(uint256).max); + pool.mint(alice, address(this), topUp, 0); + vm.stopPrank(); + } + + // Recompute withdraw amounts via view after any top-up + uint256[] memory expected = pool.computeBurnAmounts(req); + + // If expected withdraws are all zero (rounding edge), skip this iteration + if (expected[0] == 0 && expected[1] == 0 && expected[2] == 0) { + continue; + } + + // Snapshot bob balances before + uint256 b0Before = token0.balanceOf(bob); + uint256 b1Before = token1.balanceOf(bob); + uint256 b2Before = token2.balanceOf(bob); + + // Perform burn using the computed LP amount (proportional withdrawal) + pool.burn(address(this), bob, req, 0); + + // Verify bob received exactly the expected amounts + assertEq(token0.balanceOf(bob) - b0Before, expected[0], "token0 withdraw mismatch"); + assertEq(token1.balanceOf(bob) - b1Before, expected[1], "token1 withdraw mismatch"); + assertEq(token2.balanceOf(bob) - b2Before, expected[2], "token2 withdraw mismatch"); + + // totalSupply must not increase + assertTrue(pool.totalSupply() <= totalLp, "totalSupply should not increase after burn"); + totalLp = pool.totalSupply(); // update for next iteration + } + } + + /// @notice Verify computeBurnAmounts matches actual transfers performed by burn() for 10-token pool + function testComputeBurnAmountsMatchesBurn_10TokenPool() public { + uint256 totalLp = pool10.totalSupply(); + uint256[] memory burns = new uint256[](4); + burns[0] = 1; + burns[1] = totalLp / 100; + burns[2] = totalLp / 10; + burns[3] = totalLp / 2; + for (uint k = 0; k < burns.length; k++) { + uint256 req = burns[k]; + if (req == 0) req = 1; + + // Ensure this contract has enough LP to cover the requested burn; top up from alice if needed + uint256 myLp = pool10.balanceOf(address(this)); + if (myLp < req) { + uint256 topUp = req - myLp; + vm.startPrank(alice); + token0.approve(address(pool10), type(uint256).max); + token1.approve(address(pool10), type(uint256).max); + token2.approve(address(pool10), type(uint256).max); + token3.approve(address(pool10), type(uint256).max); + token4.approve(address(pool10), type(uint256).max); + token5.approve(address(pool10), type(uint256).max); + token6.approve(address(pool10), type(uint256).max); + token7.approve(address(pool10), type(uint256).max); + token8.approve(address(pool10), type(uint256).max); + token9.approve(address(pool10), type(uint256).max); + pool10.mint(alice, address(this), topUp, 0); + vm.stopPrank(); + } + + uint256[] memory expected = pool10.computeBurnAmounts(req); + + // If expected withdraws are all zero (rounding edge), skip this iteration + bool allZero = true; + for (uint i = 0; i < 10; i++) { if (expected[i] != 0) { allZero = false; break; } } + if (allZero) { continue; } + + // Snapshot bob balances + uint256[] memory beforeBal = new uint256[](10); + beforeBal[0] = token0.balanceOf(bob); + beforeBal[1] = token1.balanceOf(bob); + beforeBal[2] = token2.balanceOf(bob); + beforeBal[3] = token3.balanceOf(bob); + beforeBal[4] = token4.balanceOf(bob); + beforeBal[5] = token5.balanceOf(bob); + beforeBal[6] = token6.balanceOf(bob); + beforeBal[7] = token7.balanceOf(bob); + beforeBal[8] = token8.balanceOf(bob); + beforeBal[9] = token9.balanceOf(bob); + + pool10.burn(address(this), bob, req, 0); + + // Verify bob received each expected amount + assertEq(token0.balanceOf(bob) - beforeBal[0], expected[0], "t0 withdraw mismatch"); + assertEq(token1.balanceOf(bob) - beforeBal[1], expected[1], "t1 withdraw mismatch"); + assertEq(token2.balanceOf(bob) - beforeBal[2], expected[2], "t2 withdraw mismatch"); + assertEq(token3.balanceOf(bob) - beforeBal[3], expected[3], "t3 withdraw mismatch"); + assertEq(token4.balanceOf(bob) - beforeBal[4], expected[4], "t4 withdraw mismatch"); + assertEq(token5.balanceOf(bob) - beforeBal[5], expected[5], "t5 withdraw mismatch"); + assertEq(token6.balanceOf(bob) - beforeBal[6], expected[6], "t6 withdraw mismatch"); + assertEq(token7.balanceOf(bob) - beforeBal[7], expected[7], "t7 withdraw mismatch"); + assertEq(token8.balanceOf(bob) - beforeBal[8], expected[8], "t8 withdraw mismatch"); + assertEq(token9.balanceOf(bob) - beforeBal[9], expected[9], "t9 withdraw mismatch"); + + assertTrue(pool10.totalSupply() <= totalLp, "totalSupply should not increase after burn"); + totalLp = pool10.totalSupply(); + } + } + + /// @notice Combined gas test (mint then burn) on 10-token pool using mint() and burn(). + /// Alternates small mints and burns to keep the pool size roughly stable. + function testMintBurnGas10() public { + uint256 iterations = 50; + uint256 input = 1_000; + + // Ensure alice has enough tokens for all mints across 10 tokens + for (uint i = 0; i < 10; i++) { + // mint to alice corresponding token; use token0..token9 mapping in setUp ordering + if (i == 0) token0.mint(alice, iterations * input * 2); + else if (i == 1) token1.mint(alice, iterations * input * 2); + else if (i == 2) token2.mint(alice, iterations * input * 2); + else if (i == 3) token3.mint(alice, iterations * input * 2); + else if (i == 4) token4.mint(alice, iterations * input * 2); + else if (i == 5) token5.mint(alice, iterations * input * 2); + else if (i == 6) token6.mint(alice, iterations * input * 2); + else if (i == 7) token7.mint(alice, iterations * input * 2); + else if (i == 8) token8.mint(alice, iterations * input * 2); + else if (i == 9) token9.mint(alice, iterations * input * 2); + } + + vm.startPrank(alice); + // Approve pool10 to transfer tokens for proportional mint + token0.approve(address(pool10), type(uint256).max); + token1.approve(address(pool10), type(uint256).max); + token2.approve(address(pool10), type(uint256).max); + token3.approve(address(pool10), type(uint256).max); + token4.approve(address(pool10), type(uint256).max); + token5.approve(address(pool10), type(uint256).max); + token6.approve(address(pool10), type(uint256).max); + token7.approve(address(pool10), type(uint256).max); + token8.approve(address(pool10), type(uint256).max); + token9.approve(address(pool10), type(uint256).max); + + for (uint256 k = 0; k < iterations; k++) { + uint256 lpRequest = 1; + + uint256 lpBefore = pool10.balanceOf(alice); + pool10.mint(alice, alice, lpRequest, 0); + uint256 lpAfter = pool10.balanceOf(alice); + uint256 actualMinted = lpAfter - lpBefore; + + if (actualMinted == 0) continue; + + pool10.burn(alice, alice, actualMinted, 0); + } + + vm.stopPrank(); + } + + /// @notice Basic test for swapMint: single-token deposit -> LP minted + function testSwapMintBasic() public { + // alice must approve pool to transfer token0 + vm.startPrank(alice); + token0.approve(address(pool), type(uint256).max); + + uint256 aliceBalBefore = token0.balanceOf(alice); + uint256 aliceLpBefore = pool.balanceOf(alice); + + uint256 input = 10_000; + // Call swapMint as alice, receive LP to alice + uint256 minted = pool.swapMint(alice, alice, 0, input, 0); + + // minted should be > 0 + assertTrue(minted > 0, "swapMint should mint LP"); + + // Alice token balance must have decreased by at most input (fee included) + uint256 aliceBalAfter = token0.balanceOf(alice); + assertTrue(aliceBalAfter <= aliceBalBefore, "alice token balance should not increase"); + assertTrue(aliceBalBefore - aliceBalAfter <= input, "alice spent more than provided"); + + // Alice LP balance increased by minted + uint256 aliceLpAfter = pool.balanceOf(alice); + assertTrue(aliceLpAfter >= aliceLpBefore + minted, "alice should receive minted LP"); + + vm.stopPrank(); + } + + /// @notice Large input to swapMint should not over-consume: consumed <= provided + function testSwapMintLargeInputPartial() public { + // Very large input relative to pool + uint256 largeInput = 10_000_000_000; // intentionally large + + // Ensure alice has sufficient tokens for this large test input (mint top-up) + token0.mint(alice, largeInput); + + vm.startPrank(alice); + token0.approve(address(pool), type(uint256).max); + + uint256 aliceBalBefore = token0.balanceOf(alice); + + uint256 minted = pool.swapMint(alice, alice, 0, largeInput, 0); + + // minted should be > 0 + assertTrue(minted > 0, "swapMint large input should still mint LP"); + + uint256 aliceBalAfter = token0.balanceOf(alice); + uint256 spent = aliceBalBefore - aliceBalAfter; + + // Spent must be <= provided largeInput + assertTrue(spent <= largeInput, "swapMint must not consume more than provided"); + + // Some consumption occurred + assertTrue(spent > 0, "swapMint should have consumed some tokens"); + + vm.stopPrank(); + } + + /// @notice Basic burnSwap test: burn LP (from this contract) and receive single-token payout to bob + function testBurnSwapBasic() public { + // Use a fraction of the pool's supply to burn + uint256 supplyBefore = pool.totalSupply(); + assertTrue(supplyBefore > 0, "precondition: supply>0"); + + uint256 lpToBurn = supplyBefore / 10; + if (lpToBurn == 0) lpToBurn = 1; + + // Choose target token index 0 + uint256 target = 0; + + // Bob's balance before + uint256 bobBefore = token0.balanceOf(bob); + + // Call burnSwap where this contract is the payer (it holds initial LP from setUp) + uint256 payout = pool.burnSwap(address(this), bob, lpToBurn, target, 0); + + // Payout must be > 0 + assertTrue(payout > 0, "burnSwap should produce a payout"); + + // Bob's balance increased by at least payout + uint256 bobAfter = token0.balanceOf(bob); + assertTrue(bobAfter >= bobBefore + payout, "Bob should receive payout tokens"); + + // Supply decreased by at least lpToBurn (burn event should have burned exactly lpToBurn) + uint256 supplyAfter = pool.totalSupply(); + assertTrue(supplyAfter <= supplyBefore - lpToBurn, "totalSupply should decrease by burned LP"); + } + + /* ---------------------- + Flash Loan Tests + ---------------------- */ + + /// @notice Setup a flash borrower for testing + function setupFlashBorrower() internal returns (FlashBorrower borrower) { + // Create array of token addresses for borrower + address[] memory tokenAddresses = new address[](3); + tokenAddresses[0] = address(token0); + tokenAddresses[1] = address(token1); + tokenAddresses[2] = address(token2); + + // Deploy the borrower contract + borrower = new FlashBorrower(address(pool), tokenAddresses); + + // Mint tokens to alice to be used for repayments + token0.mint(alice, INIT_BAL * 2); + token1.mint(alice, INIT_BAL * 2); + token2.mint(alice, INIT_BAL * 2); + + // Alice approves borrower to transfer tokens on their behalf for repayment + vm.startPrank(alice); + token0.approve(address(borrower), type(uint256).max); + token1.approve(address(borrower), type(uint256).max); + token2.approve(address(borrower), type(uint256).max); + vm.stopPrank(); + } + + /// @notice Test flash loan with a single token + function testFlashLoanSingleToken() public { + FlashBorrower borrower = setupFlashBorrower(); + + // Configure borrower to repay normally + borrower.setAction(FlashBorrower.Action.NORMAL, alice); + + // Create loan request for token0 only + uint256[] memory amounts = new uint256[](3); + amounts[0] = 1000; // Only borrow token0 + + // Record balances before flash + uint256 aliceToken0Before = token0.balanceOf(alice); + uint256 poolToken0Before = token0.balanceOf(address(pool)); + + // Execute flash loan + borrower.flash(amounts); + + // Net change for alice should equal the flash fee (principal is returned during repayment) + uint256 fee = (amounts[0] * pool.flashFeePpm() + 1_000_000 - 1) / 1_000_000; // ceil fee calculation + uint256 expectedAliceDecrease = fee; + assertEq( + aliceToken0Before - token0.balanceOf(alice), + expectedAliceDecrease, + "Alice should pay flash fee" + ); + + // Check pool's balance increased by the fee + assertEq( + token0.balanceOf(address(pool)), + poolToken0Before + fee, + "Pool should receive fee" + ); + } + + /// @notice Test flash loan with multiple tokens + function testFlashLoanMultipleTokens() public { + FlashBorrower borrower = setupFlashBorrower(); + + // Configure borrower to repay normally + borrower.setAction(FlashBorrower.Action.NORMAL, alice); + + // Create loan request for all tokens + uint256[] memory amounts = new uint256[](3); + amounts[0] = 1000; + amounts[1] = 2000; + amounts[2] = 3000; + + // Record balances before flash + uint256[] memory aliceBalancesBefore = new uint256[](3); + uint256[] memory poolBalancesBefore = new uint256[](3); + + aliceBalancesBefore[0] = token0.balanceOf(alice); + aliceBalancesBefore[1] = token1.balanceOf(alice); + aliceBalancesBefore[2] = token2.balanceOf(alice); + + poolBalancesBefore[0] = token0.balanceOf(address(pool)); + poolBalancesBefore[1] = token1.balanceOf(address(pool)); + poolBalancesBefore[2] = token2.balanceOf(address(pool)); + + // Execute flash loan + borrower.flash(amounts); + + // Check balances for each token + for (uint256 i = 0; i < 3; i++) { + uint256 fee = (amounts[i] * pool.flashFeePpm() + 1_000_000 - 1) / 1_000_000; // ceil fee calculation + uint256 expectedAliceDecrease = fee; + + IERC20 token; + if (i == 0) token = token0; + else if (i == 1) token = token1; + else token = token2; + + // Net change for Alice should equal the flash fee for this token (principal was returned) + assertEq( + aliceBalancesBefore[i] - token.balanceOf(alice), + expectedAliceDecrease, + "Alice should pay flash fee for token" + ); + + // Pool's balance increased by fee + assertEq( + token.balanceOf(address(pool)), + poolBalancesBefore[i] + fee, + "Pool should receive fee for token" + ); + } + } + + /// @notice Test flash loan with some zero amounts (should be skipped) + function testFlashLoanWithZeroAmounts() public { + FlashBorrower borrower = setupFlashBorrower(); + + // Configure borrower to repay normally + borrower.setAction(FlashBorrower.Action.NORMAL, alice); + + // Create loan request with mix of zero and non-zero amounts + uint256[] memory amounts = new uint256[](3); + amounts[0] = 0; // Zero - should be skipped + amounts[1] = 2000; // Non-zero + amounts[2] = 0; // Zero - should be skipped + + // Record balances before flash + uint256 aliceToken1Before = token1.balanceOf(alice); + uint256 poolToken1Before = token1.balanceOf(address(pool)); + + // Tokens that should remain unchanged + uint256 aliceToken0Before = token0.balanceOf(alice); + uint256 aliceToken2Before = token2.balanceOf(alice); + uint256 poolToken0Before = token0.balanceOf(address(pool)); + uint256 poolToken2Before = token2.balanceOf(address(pool)); + + // Execute flash loan + borrower.flash(amounts); + + // Check token1 balances changed appropriately + uint256 fee = (amounts[1] * pool.flashFeePpm() + 1_000_000 - 1) / 1_000_000; // ceil fee calculation + uint256 expectedAliceDecrease = fee; + + assertEq( + aliceToken1Before - token1.balanceOf(alice), + expectedAliceDecrease, + "Alice should pay flash fee for token1" + ); + + assertEq( + token1.balanceOf(address(pool)), + poolToken1Before + fee, + "Pool should receive fee for token1" + ); + + // Check token0 and token2 balances remained unchanged + assertEq(token0.balanceOf(alice), aliceToken0Before, "Alice token0 balance should be unchanged"); + assertEq(token2.balanceOf(alice), aliceToken2Before, "Alice token2 balance should be unchanged"); + assertEq(token0.balanceOf(address(pool)), poolToken0Before, "Pool token0 balance should be unchanged"); + assertEq(token2.balanceOf(address(pool)), poolToken2Before, "Pool token2 balance should be unchanged"); + } + + /// @notice Test that flash reverts when all amounts are zero + function testFlashLoanAllZeroAmountsReverts() public { + FlashBorrower borrower = setupFlashBorrower(); + + // Configure borrower to repay normally + borrower.setAction(FlashBorrower.Action.NORMAL, alice); + + // Create loan request with all zeros + uint256[] memory amounts = new uint256[](3); + amounts[0] = 0; + amounts[1] = 0; + amounts[2] = 0; + + // Execute flash loan - should revert + vm.expectRevert(bytes("flash: no tokens requested")); + borrower.flash(amounts); + } + + /// @notice Test flash loan with incorrect repayment (none) + function testFlashLoanNoRepaymentReverts() public { + FlashBorrower borrower = setupFlashBorrower(); + + // Configure borrower to not repay anything + borrower.setAction(FlashBorrower.Action.REPAY_NONE, alice); + + // Create loan request + uint256[] memory amounts = new uint256[](3); + amounts[0] = 1000; + + // Execute flash loan - should revert on validation + vm.expectRevert(bytes("flash: repayment failed")); + borrower.flash(amounts); + } + + /// @notice Test flash loan with partial repayment (should revert) + function testFlashLoanPartialRepaymentReverts() public { + FlashBorrower borrower = setupFlashBorrower(); + + // Configure borrower to repay only half the required amount + borrower.setAction(FlashBorrower.Action.REPAY_PARTIAL, alice); + + // Create loan request + uint256[] memory amounts = new uint256[](3); + amounts[0] = 1000; + + // Execute flash loan - should revert on validation + vm.expectRevert(bytes("flash: repayment failed")); + borrower.flash(amounts); + } + + /// @notice Test flash loan with principal repayment but no fee (should revert) + function testFlashLoanNoFeeRepaymentReverts() public { + FlashBorrower borrower = setupFlashBorrower(); + + // Configure borrower to repay only the principal without fee + borrower.setAction(FlashBorrower.Action.REPAY_NO_FEE, alice); + + // Create loan request + uint256[] memory amounts = new uint256[](3); + amounts[0] = 1000; + + // Execute flash loan - should revert on validation if fee > 0 + if (pool.flashFeePpm() > 0) { + vm.expectRevert(bytes("flash: repayment failed")); + borrower.flash(amounts); + } else { + // If fee is zero, this should succeed + borrower.flash(amounts); + } + } + + /// @notice Test flash loan with exact repayment (should succeed) + function testFlashLoanExactRepayment() public { + FlashBorrower borrower = setupFlashBorrower(); + + // Configure borrower to repay exactly the required amount + borrower.setAction(FlashBorrower.Action.REPAY_EXACT, alice); + + // Create loan request + uint256[] memory amounts = new uint256[](3); + amounts[0] = 1000; + + // Record balances before flash + uint256 aliceToken0Before = token0.balanceOf(alice); + uint256 poolToken0Before = token0.balanceOf(address(pool)); + + // Execute flash loan + borrower.flash(amounts); + + // Check balances: net change for alice should equal the fee + uint256 fee = (amounts[0] * pool.flashFeePpm() + 1_000_000 - 1) / 1_000_000; // ceil fee calculation + uint256 expectedAliceDecrease = fee; + + assertEq( + aliceToken0Before - token0.balanceOf(alice), + expectedAliceDecrease, + "Alice should pay flash fee" + ); + + assertEq( + token0.balanceOf(address(pool)), + poolToken0Before + fee, + "Pool should receive fee" + ); + } + + /// @notice Test flash loan with extra repayment (donation, should succeed) + function testFlashLoanExtraRepayment() public { + FlashBorrower borrower = setupFlashBorrower(); + + // Configure borrower to repay more than required + borrower.setAction(FlashBorrower.Action.REPAY_EXTRA, alice); + + // Create loan request + uint256[] memory amounts = new uint256[](3); + amounts[0] = 1000; + + // Record balances before flash + uint256 aliceToken0Before = token0.balanceOf(alice); + uint256 poolToken0Before = token0.balanceOf(address(pool)); + + // Execute flash loan + borrower.flash(amounts); + + // Check balances - net change for alice should equal fee + extra donation (principal returned) + uint256 fee = (amounts[0] * pool.flashFeePpm() + 1_000_000 - 1) / 1_000_000; // ceil fee calculation + uint256 extra = 1; // borrower donates +1 per token in REPAY_EXTRA + uint256 expectedAliceDecrease = fee + extra; // fee plus donation + + assertEq( + aliceToken0Before - token0.balanceOf(alice), + expectedAliceDecrease, + "Alice should pay fee + extra" + ); + + assertEq( + token0.balanceOf(address(pool)), + poolToken0Before + fee + extra, + "Pool should receive fee + extra" + ); + } + + /// @notice Test computeFlashRepaymentAmounts matches flash implementation + function testComputeFlashRepaymentAmounts() public view { + // Create different loan amount scenarios + uint256[][] memory testCases = new uint256[][](3); + + // Case 1: Single token + testCases[0] = new uint256[](3); + testCases[0][0] = 1000; + testCases[0][1] = 0; + testCases[0][2] = 0; + + // Case 2: Multiple tokens + testCases[1] = new uint256[](3); + testCases[1][0] = 1000; + testCases[1][1] = 2000; + testCases[1][2] = 3000; + + // Case 3: Mix of zero and non-zero + testCases[2] = new uint256[](3); + testCases[2][0] = 0; + testCases[2][1] = 2000; + testCases[2][2] = 0; + + for (uint256 i = 0; i < testCases.length; i++) { + uint256[] memory loanAmounts = testCases[i]; + uint256[] memory repaymentAmounts = pool.computeFlashRepaymentAmounts(loanAmounts); + + // Verify each repayment amount is correctly calculated + for (uint256 j = 0; j < loanAmounts.length; j++) { + if (loanAmounts[j] == 0) { + // Zero loans should have zero repayment + assertEq(repaymentAmounts[j], 0, "Zero loan should have zero repayment"); + } else { + // Calculate expected repayment with fee + uint256 fee = (loanAmounts[j] * pool.flashFeePpm() + 1_000_000 - 1) / 1_000_000; // ceiling + uint256 expectedRepayment = loanAmounts[j] + fee; + + assertEq( + repaymentAmounts[j], + expectedRepayment, + "Repayment calculation mismatch" + ); + } + } + } + } + + /// @notice Test flash with invalid recipient + function testFlashWithZeroRecipientReverts() public { + FlashBorrower borrower = setupFlashBorrower(); + + // Configure borrower with zero recipient + borrower.setAction(FlashBorrower.Action.NORMAL, address(0)); + + // Create loan request + uint256[] memory amounts = new uint256[](3); + amounts[0] = 1000; + + // Execute flash loan - should revert due to zero recipient + vm.expectRevert(bytes("flash: zero recipient")); + borrower.flash(amounts); + } + + /// @notice Test flash with incorrect amounts length + function testFlashWithIncorrectLengthReverts() public { + // Call flash directly with incorrect length + uint256[] memory wrongLengthAmounts = new uint256[](2); // Pool has 3 tokens + wrongLengthAmounts[0] = 1000; + wrongLengthAmounts[1] = 2000; + + vm.expectRevert(bytes("flash: amounts length mismatch")); + pool.flash(alice, wrongLengthAmounts, ""); + } + + /// @notice Gas measurement: flash with single token + function testFlashGasSingleToken() public { + FlashBorrower borrower = setupFlashBorrower(); + + // Configure borrower + borrower.setAction(FlashBorrower.Action.NORMAL, alice); + + // Create loan request for single token + uint256[] memory amounts = new uint256[](3); + amounts[0] = 1000; + + // Execute flash loan 10 times to measure gas + for (uint256 i = 0; i < 10; i++) { + borrower.flash(amounts); + } + } + + /// @notice Gas measurement: flash with multiple tokens + function testFlashGasMultipleTokens() public { + FlashBorrower borrower = setupFlashBorrower(); + + // Configure borrower + borrower.setAction(FlashBorrower.Action.NORMAL, alice); + + // Create loan request for multiple tokens + uint256[] memory amounts = new uint256[](3); + amounts[0] = 1000; + amounts[1] = 2000; + amounts[2] = 3000; + + // Execute flash loan 10 times to measure gas + for (uint256 i = 0; i < 10; i++) { + borrower.flash(amounts); + } + } +}