diff --git a/src/encoding/evm/strategy_encoder/strategy_encoders.rs b/src/encoding/evm/strategy_encoder/strategy_encoders.rs index 8dc9eb7..59c359d 100644 --- a/src/encoding/evm/strategy_encoder/strategy_encoders.rs +++ b/src/encoding/evm/strategy_encoder/strategy_encoders.rs @@ -1,4 +1,8 @@ -use std::{cmp::max, collections::HashSet, str::FromStr}; +use std::{ + cmp::max, + collections::{HashMap, HashSet, VecDeque}, + str::FromStr, +}; use alloy_primitives::{aliases::U24, FixedBytes, U256, U8}; use alloy_sol_types::SolValue; @@ -13,7 +17,7 @@ use crate::encoding::{ swap_encoder::SWAP_ENCODER_REGISTRY, utils::{biguint_to_u256, bytes_to_address, encode_input, percentage_to_uint24}, }, - models::{EncodingContext, NativeAction, Solution}, + models::{EncodingContext, NativeAction, Solution, Swap}, strategy_encoder::StrategyEncoder, }; @@ -71,6 +75,12 @@ impl StrategyEncoder for SplitSwapStrategyEncoder { solution: Solution, router_address: Bytes, ) -> Result<(Vec, Bytes), EncodingError> { + validate_swaps(&solution.swaps)?; + validate_token_path_connectivity( + &solution.swaps, + &solution.given_token, + &solution.checked_token, + )?; let (permit, signature) = self.permit2.get_permit( &router_address, &solution.sender, @@ -203,6 +213,145 @@ impl StrategyEncoder for SplitSwapStrategyEncoder { } } +fn validate_swaps(swaps: &[Swap]) -> Result<(), EncodingError> { + for swap in swaps { + if swap.split >= 1.0 { + return Err(EncodingError::InvalidInput(format!( + "Split percentage must be less than 1 (100%), got {}", + swap.split + ))); + } + } + + let mut swaps_by_token: HashMap> = HashMap::new(); + + for swap in swaps { + swaps_by_token + .entry(swap.token_in.clone()) + .or_default() + .push(swap); + } + + for (token, token_swaps) in swaps_by_token { + if token_swaps.is_empty() { + return Err(EncodingError::InvalidInput(format!( + "No swaps found for token {:?}", + token + ))); + } + + // Single swaps don't need remainder handling + if token_swaps.len() == 1 { + continue; + } + + // Check if exactly one swap has 0% split and it's the last one + let zero_splits: Vec<_> = token_swaps + .iter() + .enumerate() + .filter(|(_, s)| s.split == 0.0) + .collect(); + + if zero_splits.len() != 1 { + return Err(EncodingError::InvalidInput(format!( + "Token {:?} must have exactly one 0% split for remainder handling", + token + ))); + } + + if zero_splits[0].0 != token_swaps.len() - 1 { + return Err(EncodingError::InvalidInput(format!( + "The 0% split for token {:?} must be the last swap", + token + ))); + } + + // Sum non-zero splits and validate each is >0% and <100% + let mut total_percentage = 0.0; + for swap in token_swaps + .iter() + .take(token_swaps.len() - 1) + { + if swap.split <= 0.0 { + return Err(EncodingError::InvalidInput(format!( + "Non-remainder splits must be >0% for token {:?}", + token + ))); + } + total_percentage += swap.split; + } + + // Total must be <100% to leave room for remainder + if total_percentage >= 1.0 { + return Err(EncodingError::InvalidInput(format!( + "Total of non-remainder splits for token {:?} must be <100%, got {}%", + token, + total_percentage * 100.0 + ))); + } + } + + Ok(()) +} + +fn validate_token_path_connectivity( + swaps: &[Swap], + given_token: &Bytes, + checked_token: &Bytes, +) -> Result<(), EncodingError> { + // Build directed graph of token flows + let mut graph: HashMap> = HashMap::new(); + for swap in swaps { + graph + .entry(swap.token_in.clone()) + .or_default() + .insert(swap.token_out.clone()); + } + + // BFS from given_token + let mut visited = HashSet::new(); + let mut queue = VecDeque::new(); + queue.push_back(given_token.clone()); + + while let Some(token) = queue.pop_front() { + if !visited.insert(token.clone()) { + continue; + } + + if let Some(next_tokens) = graph.get(&token) { + for next_token in next_tokens { + if !visited.contains(next_token) { + queue.push_back(next_token.clone()); + } + } + } + } + + // Verify all tokens are visited + let all_tokens: HashSet<_> = graph + .keys() + .chain(graph.values().flat_map(|v| v.iter())) + .collect(); + + for token in all_tokens { + if !visited.contains(token) { + return Err(EncodingError::InvalidInput(format!( + "Token {:?} is not connected to the main path", + token + ))); + } + } + + // Verify checked_token is reachable + if !visited.contains(checked_token) { + return Err(EncodingError::InvalidInput( + "Checked token is not reachable through swap path".to_string(), + )); + } + + Ok(()) +} + /// This strategy encoder is used for solutions that are sent directly to the pool. /// Only 1 solution with 1 swap is supported. pub struct ExecutorStrategyEncoder {}