From f32210bb1f6103a1775975604415295260de9107 Mon Sep 17 00:00:00 2001 From: TAMARA LIPOWSKI Date: Thu, 13 Feb 2025 01:58:34 -0500 Subject: [PATCH] feat: (WIP) UniswapV4 encoding - To keep any knowledge of USV4 separate from regular splits, I've made a new USV4 encoding strategy that will be used only if we detect sequential USV4 swaps. - For single USV4 swaps without necessary optimizations, the regular split swap strategy can be used - No need to change the swap struct interface to take multiple swaps - this concatenation can be done at the swap strategy level. TODO: - test - deduplicate code from split strategy - UniswapV4SwapEncoder --- .../evm/strategy_encoder/strategy_encoders.rs | 393 ++++++++++++++++-- src/encoding/evm/utils.rs | 43 +- 2 files changed, 399 insertions(+), 37 deletions(-) diff --git a/src/encoding/evm/strategy_encoder/strategy_encoders.rs b/src/encoding/evm/strategy_encoder/strategy_encoders.rs index 4e8193f..356b84c 100644 --- a/src/encoding/evm/strategy_encoder/strategy_encoders.rs +++ b/src/encoding/evm/strategy_encoder/strategy_encoders.rs @@ -1,12 +1,10 @@ use std::{ - cmp::max, collections::{HashMap, HashSet, VecDeque}, str::FromStr, }; use alloy_primitives::{aliases::U24, FixedBytes, U256, U8}; use alloy_sol_types::SolValue; -use num_bigint::BigUint; use tycho_core::{keccak256, Bytes}; use crate::encoding::{ @@ -14,7 +12,10 @@ use crate::encoding::{ evm::{ approvals::permit2::Permit2, swap_encoder::swap_encoder_registry::SwapEncoderRegistry, - utils::{biguint_to_u256, bytes_to_address, encode_input, percentage_to_uint24}, + utils::{ + biguint_to_u256, bytes_to_address, encode_input, get_min_amount_for_solution, + get_token_position, percentage_to_uint24, + }, }, models::{Chain, EncodingContext, NativeAction, Solution, Swap}, strategy_encoder::StrategyEncoder, @@ -258,6 +259,357 @@ impl SplitSwapStrategyEncoder { } } +/// To be used if there are two or more UniswapV4 swaps consecutively. They can be combined as a +/// gas optimization. +#[derive(Clone)] +pub struct UniswapV4StrategyEncoder { + swap_encoder_registry: SwapEncoderRegistry, + permit2: Permit2, + selector: String, + native_address: Bytes, + wrapped_address: Bytes, +} + +impl EVMStrategyEncoder for UniswapV4StrategyEncoder {} + +impl StrategyEncoder for UniswapV4StrategyEncoder { + fn encode_strategy( + &self, + solution: Solution, + ) -> Result<(Vec, Bytes, Option), EncodingError> { + self.validate_split_percentages(&solution.swaps)?; + self.validate_swap_path( + &solution.swaps, + &solution.given_token, + &solution.checked_token, + &solution.native_action, + )?; + let (permit, signature) = self.permit2.get_permit( + &solution.router_address, + &solution.sender, + &solution.given_token, + &solution.given_amount, + )?; + let min_amount_out = get_min_amount_for_solution(solution.clone()); + + // The tokens array is composed of the given token, the checked token and all the + // intermediary tokens in between. The contract expects the tokens to be in this order. + let solution_tokens: HashSet = + vec![solution.given_token.clone(), solution.checked_token.clone()] + .into_iter() + .collect(); + + let intermediary_tokens: HashSet = solution + .swaps + .iter() + .flat_map(|swap| vec![swap.token_in.clone(), swap.token_out.clone()]) + .collect(); + let mut intermediary_tokens: Vec = intermediary_tokens + .difference(&solution_tokens) + .cloned() + .collect(); + // this is only to make the test deterministic (same index for the same token for different + // runs) + intermediary_tokens.sort(); + + let (mut unwrap, mut wrap) = (false, false); + if let Some(action) = solution.native_action.clone() { + match action { + NativeAction::Wrap => wrap = true, + NativeAction::Unwrap => unwrap = true, + } + } + + let mut tokens = Vec::with_capacity(2 + intermediary_tokens.len()); + if wrap { + tokens.push(self.wrapped_address.clone()); + } else { + tokens.push(solution.given_token.clone()); + } + tokens.extend(intermediary_tokens); + + if unwrap { + tokens.push(self.wrapped_address.clone()); + } else { + tokens.push(solution.checked_token.clone()); + } + + let mut swaps = vec![]; + + let mut previous_protocol_data: Vec = vec![]; + let mut first_usv4_in_token: Bytes = Bytes::default(); + let mut last_swap_was_usv4 = false; + + for swap in solution.swaps.iter() { + let swap_encoder = self + .get_swap_encoder(&swap.component.protocol_system) + .ok_or_else(|| { + EncodingError::InvalidInput(format!( + "Swap encoder not found for protocol: {}", + swap.component.protocol_system + )) + })?; + + let current_swap_is_usv4 = swap.component.protocol_system == "uniswap_v4"; + let encoding_context = EncodingContext { + receiver: solution.router_address.clone(), + exact_out: solution.exact_out, + router_address: solution.router_address.clone(), + }; + let mut protocol_data = swap_encoder.encode_swap(swap.clone(), encoding_context)?; + let in_token; + + if current_swap_is_usv4 { + if !last_swap_was_usv4 { + // This is the first usv4 swap of a potential sequence. Store the input token + first_usv4_in_token = swap.clone().token_in; + } else { + // This is the second or later usv4 swap of a sequence. Concatenate the protocol + // data with the previous swap's protocol data + protocol_data = + [previous_protocol_data.clone(), protocol_data.clone()].concat(); + } + in_token = first_usv4_in_token.clone(); + previous_protocol_data = protocol_data.clone(); + } else { + in_token = swap.clone().token_in; + // This is not a USV4 swap. Clear previous USV4 protocol data. + previous_protocol_data = vec![]; + } + + // This is the hardest part - we will need to have the input token be the first of the + // USV4 sequence, and the output token be the last, essentially removing + // intermediate tokens and pretending they don't exist... I think? + let swap_data = self.encode_swap_header( + get_token_position(tokens.clone(), in_token)?, + get_token_position(tokens.clone(), swap.clone().token_out)?, + percentage_to_uint24(swap.split), + Bytes::from_str(swap_encoder.executor_address()).map_err(|_| { + EncodingError::FatalError("Invalid executor address".to_string()) + })?, + self.encode_executor_selector(swap_encoder.executor_selector()), + protocol_data, + ); + + // If the last swap was usv4, and this swap is also usv4, replace the last swap_data + // with the updated swap_data, which will contain both swaps, along with the + // proper input and output tokens + if last_swap_was_usv4 && current_swap_is_usv4 { + let swaps_len = swaps.len() - 1; + swaps[swaps_len] = swap_data; + } else { + swaps.push(swap_data); + } + last_swap_was_usv4 = current_swap_is_usv4; + } + + let encoded_swaps = self.ple_encode(swaps); + let method_calldata = ( + biguint_to_u256(&solution.given_amount), + bytes_to_address(&solution.given_token)?, + bytes_to_address(&solution.checked_token)?, + biguint_to_u256(&min_amount_out), + wrap, + unwrap, + U256::from(tokens.len()), + bytes_to_address(&solution.receiver)?, + permit, + signature.as_bytes().to_vec(), + encoded_swaps, + ) + .abi_encode(); + + let contract_interaction = encode_input(&self.selector, method_calldata); + Ok((contract_interaction, solution.router_address, None)) + } + + fn get_swap_encoder(&self, protocol_system: &str) -> Option<&Box> { + self.swap_encoder_registry + .get_encoder(protocol_system) + } + + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } +} + +impl UniswapV4StrategyEncoder { + #[allow(dead_code)] + pub fn new( + signer_pk: String, + chain: Chain, + swap_encoder_registry: SwapEncoderRegistry, + ) -> Result { + let selector = "swap(uint256,address,address,uint256,bool,bool,uint256,address,((address,uint160,uint48,uint48),address,uint256),bytes,bytes)".to_string(); + Ok(Self { + permit2: Permit2::new(signer_pk, chain.clone())?, + selector, + swap_encoder_registry, + native_address: chain.native_token()?, + wrapped_address: chain.wrapped_token()?, + }) + } + + /// Raises an error if the split percentages are invalid. + /// + /// Split percentages are considered valid if all the following conditions are met: + /// * Each split amount is < 1 (100%) + /// * There is exactly one 0% split for each token, and it's the last swap specified, signifying + /// to the router to send the remainder of the token to the designated protocol + /// * The sum of all non-remainder splits for each token is < 1 (100%) + /// * There are no negative split amounts + fn validate_split_percentages(&self, swaps: &[Swap]) -> Result<(), EncodingError> { + let mut swaps_by_token: HashMap> = HashMap::new(); + for swap in swaps { + if swap.split >= 1.0 { + return Err(EncodingError::InvalidInput(format!( + "Split percentage must be less than 1 (100%), got {}", + swap.split + ))); + } + swaps_by_token + .entry(swap.token_in.clone()) + .or_default() + .push(swap); + } + + for (token, token_swaps) in swaps_by_token { + // Single swaps don't need remainder handling + if token_swaps.len() == 1 { + if token_swaps[0].split != 0.0 { + return Err(EncodingError::InvalidInput(format!( + "Single swap must have 0% split for token {:?}", + token + ))); + } + continue; + } + + let mut found_zero_split = false; + let mut total_percentage = 0.0; + for (i, swap) in token_swaps.iter().enumerate() { + match (swap.split == 0.0, i == token_swaps.len() - 1) { + (true, false) => { + return Err(EncodingError::InvalidInput(format!( + "The 0% split for token {:?} must be the last swap", + token + ))) + } + (true, true) => found_zero_split = true, + (false, _) => { + if swap.split < 0.0 { + return Err(EncodingError::InvalidInput(format!( + "All splits must be >= 0% for token {:?}", + token + ))); + } + total_percentage += swap.split; + } + } + } + + if !found_zero_split { + return Err(EncodingError::InvalidInput(format!( + "Token {:?} must have exactly one 0% split for remainder handling", + token + ))); + } + + // Total must be <100% to leave room for remainder + if total_percentage >= 1.0 { + return Err(EncodingError::InvalidInput(format!( + "Total of non-remainder splits for token {:?} must be <100%, got {}%", + token, + total_percentage * 100.0 + ))); + } + } + + Ok(()) + } + + /// Raises an error if swaps do not represent a valid path from the given token to the checked + /// token. + /// + /// A path is considered valid if all the following conditions are met: + /// * The checked token is reachable from the given token through the swap path + /// * There are no tokens which are unconnected from the main path + /// + /// If the given token is the native token and the native action is WRAP, it will be converted + /// to the wrapped token before validating the swap path. The same principle applies for the + /// checked token and the UNWRAP action. + fn validate_swap_path( + &self, + swaps: &[Swap], + given_token: &Bytes, + checked_token: &Bytes, + native_action: &Option, + ) -> Result<(), EncodingError> { + // Convert ETH to WETH only if there's a corresponding wrap/unwrap action + let given_token = if *given_token == *self.native_address { + match native_action { + Some(NativeAction::Wrap) => &self.wrapped_address, + _ => given_token, + } + } else { + given_token + }; + + let checked_token = if *checked_token == *self.native_address { + match native_action { + Some(NativeAction::Unwrap) => &self.wrapped_address, + _ => checked_token, + } + } else { + checked_token + }; + + // Build directed graph of token flows + let mut graph: HashMap<&Bytes, HashSet<&Bytes>> = HashMap::new(); + for swap in swaps { + graph + .entry(&swap.token_in) + .or_default() + .insert(&swap.token_out); + } + + // BFS from validation_given + let mut visited = HashSet::new(); + let mut queue = VecDeque::new(); + queue.push_back(given_token); + + while let Some(token) = queue.pop_front() { + if !visited.insert(token) { + continue; + } + + // Early success check + if token == checked_token && visited.len() == graph.len() + 1 { + return Ok(()); + } + + if let Some(next_tokens) = graph.get(token) { + for &next_token in next_tokens { + if !visited.contains(next_token) { + queue.push_back(next_token); + } + } + } + } + + // If we get here, either checked_token wasn't reached or not all tokens were visited + if !visited.contains(checked_token) { + Err(EncodingError::InvalidInput( + "Checked token is not reachable through swap path".to_string(), + )) + } else { + Err(EncodingError::InvalidInput( + "Some tokens are not connected to the main path".to_string(), + )) + } + } +} + impl EVMStrategyEncoder for SplitSwapStrategyEncoder {} impl StrategyEncoder for SplitSwapStrategyEncoder { @@ -278,19 +630,8 @@ impl StrategyEncoder for SplitSwapStrategyEncoder { &solution.given_token, &solution.given_amount, )?; - let mut min_amount_out = solution - .checked_amount - .unwrap_or(BigUint::ZERO); + let min_amount_out = get_min_amount_for_solution(solution.clone()); - if let (Some(expected_amount), Some(slippage)) = - (solution.expected_amount.as_ref(), solution.slippage) - { - let one_hundred = BigUint::from(100u32); - let slippage_percent = BigUint::from((slippage * 100.0) as u32); - let multiplier = &one_hundred - slippage_percent; - let expected_amount_with_slippage = (expected_amount * &multiplier) / &one_hundred; - min_amount_out = max(min_amount_out, expected_amount_with_slippage); - } // The tokens array is composed of the given token, the checked token and all the // intermediary tokens in between. The contract expects the tokens to be in this order. let solution_tokens: HashSet = @@ -351,26 +692,8 @@ impl StrategyEncoder for SplitSwapStrategyEncoder { }; let protocol_data = swap_encoder.encode_swap(swap.clone(), encoding_context)?; let swap_data = self.encode_swap_header( - U8::from( - tokens - .iter() - .position(|t| *t == swap.token_in) - .ok_or_else(|| { - EncodingError::InvalidInput( - "In token not found in tokens array".to_string(), - ) - })?, - ), - U8::from( - tokens - .iter() - .position(|t| *t == swap.token_out) - .ok_or_else(|| { - EncodingError::InvalidInput( - "Out token not found in tokens array".to_string(), - ) - })?, - ), + get_token_position(tokens.clone(), swap.token_in.clone())?, + get_token_position(tokens.clone(), swap.token_out.clone())?, percentage_to_uint24(swap.split), Bytes::from_str(swap_encoder.executor_address()).map_err(|_| { EncodingError::FatalError("Invalid executor address".to_string()) diff --git a/src/encoding/evm/utils.rs b/src/encoding/evm/utils.rs index 5d3a70a..d4017f6 100644 --- a/src/encoding/evm/utils.rs +++ b/src/encoding/evm/utils.rs @@ -1,8 +1,10 @@ -use alloy_primitives::{aliases::U24, Address, Keccak256, U256}; +use std::cmp::max; + +use alloy_primitives::{aliases::U24, Address, Keccak256, U256, U8}; use num_bigint::BigUint; use tycho_core::Bytes; -use crate::encoding::errors::EncodingError; +use crate::encoding::{errors::EncodingError, models::Solution}; /// Safely converts a `Bytes` object to an `Address` object. /// @@ -52,3 +54,40 @@ pub fn percentage_to_uint24(decimal: f64) -> U24 { let scaled = (decimal / 1.0) * (MAX_UINT24 as f64); U24::from(scaled.round()) } + +/// Gets the minimum amount out for a solution to pass when executed on-chain. +/// +/// The minimum amount is calculated based on the expected amount and the slippage percentage, if +/// passed. If this information is not passed, the user-passed checked amount will be used. +/// If both the slippage and minimum user-passed checked amount are passed, the maximum of the two +/// will be used. +/// If neither are passed, the minimum amount will be zero. +pub fn get_min_amount_for_solution(solution: Solution) -> BigUint { + let mut min_amount_out = solution + .checked_amount + .unwrap_or(BigUint::ZERO); + + if let (Some(expected_amount), Some(slippage)) = + (solution.expected_amount.as_ref(), solution.slippage) + { + let one_hundred = BigUint::from(100u32); + let slippage_percent = BigUint::from((slippage * 100.0) as u32); + let multiplier = &one_hundred - slippage_percent; + let expected_amount_with_slippage = (expected_amount * &multiplier) / &one_hundred; + min_amount_out = max(min_amount_out, expected_amount_with_slippage); + } + min_amount_out +} + +/// Gets the position of a token in a list of tokens. +pub fn get_token_position(tokens: Vec, token: Bytes) -> Result { + let position = U8::from( + tokens + .iter() + .position(|t| *t == token) + .ok_or_else(|| { + EncodingError::InvalidInput(format!("Token {:?} not found in tokens array", token)) + })?, + ); + Ok(position) +}