diff --git a/src/encoding/evm/strategy_encoder/strategy_encoders.rs b/src/encoding/evm/strategy_encoder/strategy_encoders.rs index 0de3fb8..070b166 100644 --- a/src/encoding/evm/strategy_encoder/strategy_encoders.rs +++ b/src/encoding/evm/strategy_encoder/strategy_encoders.rs @@ -67,6 +67,145 @@ impl SplitSwapStrategyEncoder { let selector = "swap(uint256,address,address,uint256,bool,bool,uint256,address,((address,uint160,uint48,uint48),address,uint256),bytes,bytes)".to_string(); Ok(Self { permit2: Permit2::new(signer_pk, chain)?, selector }) } + + fn validate_swaps(&self, swaps: &[Swap]) -> Result<(), EncodingError> { + for swap in swaps { + if swap.split >= 1.0 { + return Err(EncodingError::InvalidInput(format!( + "Split percentage must be less than 1 (100%), got {}", + swap.split + ))); + } + } + + let mut swaps_by_token: HashMap> = HashMap::new(); + + for swap in swaps { + swaps_by_token + .entry(swap.token_in.clone()) + .or_default() + .push(swap); + } + + for (token, token_swaps) in swaps_by_token { + if token_swaps.is_empty() { + return Err(EncodingError::InvalidInput(format!( + "No swaps found for token {:?}", + token + ))); + } + + // Single swaps don't need remainder handling + if token_swaps.len() == 1 { + if token_swaps[0].split != 0.0 { + return Err(EncodingError::InvalidInput(format!( + "Single swap must have 0% split for token {:?}", + token + ))); + } + continue; + } + + // Check if exactly one swap has 0% split and it's the last one + let mut found_zero_split = false; + for (i, swap) in token_swaps.iter().enumerate() { + match (swap.split == 0.0, i == token_swaps.len() - 1) { + (true, false) => { + return Err(EncodingError::InvalidInput(format!( + "The 0% split for token {:?} must be the last swap", + token + ))) + } + (true, true) => found_zero_split = true, + (false, _) => (), + } + } + + if !found_zero_split { + return Err(EncodingError::InvalidInput(format!( + "Token {:?} must have exactly one 0% split for remainder handling", + token + ))); + } + + // Sum non-zero splits and validate each is >0% and <100% + let mut total_percentage = 0.0; + for swap in token_swaps + .iter() + .take(token_swaps.len() - 1) + { + if swap.split <= 0.0 { + return Err(EncodingError::InvalidInput(format!( + "Non-remainder splits must be >0% for token {:?}", + token + ))); + } + total_percentage += swap.split; + } + + // Total must be <100% to leave room for remainder + if total_percentage >= 1.0 { + return Err(EncodingError::InvalidInput(format!( + "Total of non-remainder splits for token {:?} must be <100%, got {}%", + token, + total_percentage * 100.0 + ))); + } + } + + Ok(()) + } + + fn validate_token_path_connectivity( + &self, + swaps: &[Swap], + given_token: &Bytes, + checked_token: &Bytes, + ) -> Result<(), EncodingError> { + // Build directed graph of token flows + let mut graph: HashMap<&Bytes, HashSet<&Bytes>> = HashMap::new(); + for swap in swaps { + graph + .entry(&swap.token_in) + .or_default() + .insert(&swap.token_out); + } + + // BFS from given_token + let mut visited = HashSet::new(); + let mut queue = VecDeque::new(); + queue.push_back(given_token); + + while let Some(token) = queue.pop_front() { + if !visited.insert(token) { + continue; + } + + // Early success check + if token == checked_token && visited.len() == graph.len() + 1 { + return Ok(()); + } + + if let Some(next_tokens) = graph.get(token) { + for &next_token in next_tokens { + if !visited.contains(next_token) { + queue.push_back(next_token); + } + } + } + } + + // If we get here, either checked_token wasn't reached or not all tokens were visited + if !visited.contains(checked_token) { + Err(EncodingError::InvalidInput( + "Checked token is not reachable through swap path".to_string(), + )) + } else { + Err(EncodingError::InvalidInput( + "Some tokens are not connected to the main path".to_string(), + )) + } + } } impl EVMStrategyEncoder for SplitSwapStrategyEncoder {} impl StrategyEncoder for SplitSwapStrategyEncoder { @@ -75,8 +214,8 @@ impl StrategyEncoder for SplitSwapStrategyEncoder { solution: Solution, router_address: Bytes, ) -> Result<(Vec, Bytes), EncodingError> { - validate_swaps(&solution.swaps)?; - validate_token_path_connectivity( + self.validate_swaps(&solution.swaps)?; + self.validate_token_path_connectivity( &solution.swaps, &solution.given_token, &solution.checked_token, @@ -213,144 +352,6 @@ impl StrategyEncoder for SplitSwapStrategyEncoder { } } -fn validate_swaps(swaps: &[Swap]) -> Result<(), EncodingError> { - for swap in swaps { - if swap.split >= 1.0 { - return Err(EncodingError::InvalidInput(format!( - "Split percentage must be less than 1 (100%), got {}", - swap.split - ))); - } - } - - let mut swaps_by_token: HashMap> = HashMap::new(); - - for swap in swaps { - swaps_by_token - .entry(swap.token_in.clone()) - .or_default() - .push(swap); - } - - for (token, token_swaps) in swaps_by_token { - if token_swaps.is_empty() { - return Err(EncodingError::InvalidInput(format!( - "No swaps found for token {:?}", - token - ))); - } - - // Single swaps don't need remainder handling - if token_swaps.len() == 1 { - if token_swaps[0].split != 0.0 { - return Err(EncodingError::InvalidInput(format!( - "Single swap must have 0% split for token {:?}", - token - ))); - } - continue; - } - - // Check if exactly one swap has 0% split and it's the last one - let mut found_zero_split = false; - for (i, swap) in token_swaps.iter().enumerate() { - match (swap.split == 0.0, i == token_swaps.len() - 1) { - (true, false) => { - return Err(EncodingError::InvalidInput(format!( - "The 0% split for token {:?} must be the last swap", - token - ))) - } - (true, true) => found_zero_split = true, - (false, _) => (), - } - } - - if !found_zero_split { - return Err(EncodingError::InvalidInput(format!( - "Token {:?} must have exactly one 0% split for remainder handling", - token - ))); - } - - // Sum non-zero splits and validate each is >0% and <100% - let mut total_percentage = 0.0; - for swap in token_swaps - .iter() - .take(token_swaps.len() - 1) - { - if swap.split <= 0.0 { - return Err(EncodingError::InvalidInput(format!( - "Non-remainder splits must be >0% for token {:?}", - token - ))); - } - total_percentage += swap.split; - } - - // Total must be <100% to leave room for remainder - if total_percentage >= 1.0 { - return Err(EncodingError::InvalidInput(format!( - "Total of non-remainder splits for token {:?} must be <100%, got {}%", - token, - total_percentage * 100.0 - ))); - } - } - - Ok(()) -} - -fn validate_token_path_connectivity( - swaps: &[Swap], - given_token: &Bytes, - checked_token: &Bytes, -) -> Result<(), EncodingError> { - // Build directed graph of token flows - let mut graph: HashMap<&Bytes, HashSet<&Bytes>> = HashMap::new(); - for swap in swaps { - graph - .entry(&swap.token_in) - .or_default() - .insert(&swap.token_out); - } - - // BFS from given_token - let mut visited = HashSet::new(); - let mut queue = VecDeque::new(); - queue.push_back(given_token); - - while let Some(token) = queue.pop_front() { - if !visited.insert(token) { - continue; - } - - // Early success check - if token == checked_token && visited.len() == graph.len() + 1 { - return Ok(()); - } - - if let Some(next_tokens) = graph.get(token) { - for &next_token in next_tokens { - if !visited.contains(next_token) { - queue.push_back(next_token); - } - } - } - } - - // If we get here, either checked_token wasn't reached or not all tokens were visited - if !visited.contains(checked_token) { - Err(EncodingError::InvalidInput( - "Checked token is not reachable through swap path".to_string(), - )) - } else { - Err(EncodingError::InvalidInput( - "Some tokens are not connected to the main path".to_string(), - )) - } -} - /// This strategy encoder is used for solutions that are sent directly to the pool. /// Only 1 solution with 1 swap is supported. pub struct ExecutorStrategyEncoder {} @@ -780,8 +781,15 @@ mod tests { println!("{}", _hex_calldata); } + fn get_mock_split_swap_strategy_encoder() -> SplitSwapStrategyEncoder { + let private_key = + "0x123456789abcdef123456789abcdef123456789abcdef123456789abcdef1234".to_string(); + SplitSwapStrategyEncoder::new(private_key, Chain::Ethereum).unwrap() + } + #[test] fn test_validate_token_path_connectivity_single_swap() { + let encoder = get_mock_split_swap_strategy_encoder(); let weth = Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap(); let dai = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap(); let swaps = vec![Swap { @@ -794,12 +802,13 @@ mod tests { token_out: dai.clone(), split: 0f64, }]; - let result = validate_token_path_connectivity(&swaps, &weth, &dai); + let result = encoder.validate_token_path_connectivity(&swaps, &weth, &dai); assert_eq!(result, Ok(())); } #[test] fn test_validate_token_path_connectivity_multiple_swaps() { + let encoder = get_mock_split_swap_strategy_encoder(); let weth = Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap(); let dai = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap(); let usdc = Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap(); @@ -825,12 +834,13 @@ mod tests { split: 0f64, }, ]; - let result = validate_token_path_connectivity(&swaps, &weth, &usdc); + let result = encoder.validate_token_path_connectivity(&swaps, &weth, &usdc); assert_eq!(result, Ok(())); } #[test] fn test_validate_token_path_connectivity_multiple_swaps_failure() { + let encoder = get_mock_split_swap_strategy_encoder(); let weth = Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap(); let dai = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap(); let usdc = Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap(); @@ -860,7 +870,7 @@ mod tests { split: 0.0, }, ]; - let result = validate_token_path_connectivity(&disconnected_swaps, &weth, &usdc); + let result = encoder.validate_token_path_connectivity(&disconnected_swaps, &weth, &usdc); assert!(matches!( result, Err(EncodingError::InvalidInput(msg)) if msg.contains("not reachable through swap path") @@ -877,7 +887,7 @@ mod tests { token_out: dai.clone(), split: 1.0, }]; - let result = validate_token_path_connectivity(&unreachable_swaps, &weth, &usdc); + let result = encoder.validate_token_path_connectivity(&unreachable_swaps, &weth, &usdc); assert!(matches!( result, Err(EncodingError::InvalidInput(msg)) if msg.contains("not reachable through swap path") @@ -885,7 +895,7 @@ mod tests { // Test case 3: Empty swaps let empty_swaps: Vec = vec![]; - let result = validate_token_path_connectivity(&empty_swaps, &weth, &usdc); + let result = encoder.validate_token_path_connectivity(&empty_swaps, &weth, &usdc); assert!(matches!( result, Err(EncodingError::InvalidInput(msg)) if msg.contains("not reachable through swap path") @@ -894,6 +904,7 @@ mod tests { #[test] fn test_validate_swaps_single_swap() { + let encoder = get_mock_split_swap_strategy_encoder(); let weth = Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap(); let dai = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap(); let swaps = vec![Swap { @@ -906,7 +917,7 @@ mod tests { token_out: dai.clone(), split: 0f64, }]; - let result = validate_swaps(&swaps); + let result = encoder.validate_swaps(&swaps); assert_eq!(result, Ok(())); } @@ -948,7 +959,10 @@ mod tests { split: 0.0, // Remainder (20%) }, ]; - assert!(validate_swaps(&valid_swaps).is_ok()); + let encoder = get_mock_split_swap_strategy_encoder(); + assert!(encoder + .validate_swaps(&valid_swaps) + .is_ok()); } #[test] @@ -979,8 +993,9 @@ mod tests { split: 0.3, }, ]; + let encoder = get_mock_split_swap_strategy_encoder(); assert!(matches!( - validate_swaps(&invalid_total_swaps), + encoder.validate_swaps(&invalid_total_swaps), Err(EncodingError::InvalidInput(msg)) if msg.contains("must have exactly one 0% split") )); @@ -1008,7 +1023,7 @@ mod tests { }, ]; assert!(matches!( - validate_swaps(&invalid_zero_position_swaps), + encoder.validate_swaps(&invalid_zero_position_swaps), Err(EncodingError::InvalidInput(msg)) if msg.contains("must be the last swap") )); @@ -1045,8 +1060,9 @@ mod tests { split: 0.0, }, ]; + let encoder = get_mock_split_swap_strategy_encoder(); assert!(matches!( - validate_swaps(&invalid_overflow_swaps), + encoder.validate_swaps(&invalid_overflow_swaps), Err(EncodingError::InvalidInput(msg)) if msg.contains("must be <100%") )); }