diff --git a/src/encoding/evm/strategy_encoder/strategy_encoders.rs b/src/encoding/evm/strategy_encoder/strategy_encoders.rs index ec19d1b..c0f422a 100644 --- a/src/encoding/evm/strategy_encoder/strategy_encoders.rs +++ b/src/encoding/evm/strategy_encoder/strategy_encoders.rs @@ -13,7 +13,7 @@ use crate::encoding::{ errors::EncodingError, evm::{ approvals::permit2::Permit2, - constants::WETH_ADDRESS, + constants::{NATIVE_ADDRESS, WETH_ADDRESS}, swap_encoder::SWAP_ENCODER_REGISTRY, utils::{biguint_to_u256, bytes_to_address, encode_input, percentage_to_uint24}, }, @@ -68,7 +68,7 @@ impl SplitSwapStrategyEncoder { Ok(Self { permit2: Permit2::new(signer_pk, chain)?, selector }) } - fn validate_swaps(&self, swaps: &[Swap]) -> Result<(), EncodingError> { + fn validate_split_percentages(&self, swaps: &[Swap]) -> Result<(), EncodingError> { let mut swaps_by_token: HashMap> = HashMap::new(); for swap in swaps { if swap.split >= 1.0 { @@ -145,7 +145,7 @@ impl SplitSwapStrategyEncoder { Ok(()) } - fn validate_token_path_connectivity( + fn validate_swap_path( &self, swaps: &[Swap], given_token: &Bytes, @@ -153,16 +153,10 @@ impl SplitSwapStrategyEncoder { ) -> Result<(), EncodingError> { // Special case: If given_token is ETH or checked_token is ETH, treat it as WETH for path // validation - let eth_address = Bytes::from_str("0x0000000000000000000000000000000000000000") - .map_err(|_| EncodingError::FatalError("Invalid ETH address".to_string()))?; - let weth_address = Bytes::from_str("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2") - .map_err(|_| EncodingError::FatalError("Invalid WETH address".to_string()))?; + let given_token = if *given_token == *NATIVE_ADDRESS { &WETH_ADDRESS } else { given_token }; - let validation_given = - if given_token == ð_address { &weth_address } else { given_token }; - - let validation_checked = - if checked_token == ð_address { &weth_address } else { checked_token }; + let checked_token = + if *checked_token == *NATIVE_ADDRESS { &WETH_ADDRESS } else { checked_token }; // Build directed graph of token flows let mut graph: HashMap<&Bytes, HashSet<&Bytes>> = HashMap::new(); @@ -176,7 +170,7 @@ impl SplitSwapStrategyEncoder { // BFS from validation_given let mut visited = HashSet::new(); let mut queue = VecDeque::new(); - queue.push_back(validation_given); + queue.push_back(given_token); while let Some(token) = queue.pop_front() { if !visited.insert(token) { @@ -184,7 +178,7 @@ impl SplitSwapStrategyEncoder { } // Early success check - if token == validation_checked && visited.len() == graph.len() + 1 { + if token == checked_token && visited.len() == graph.len() + 1 { return Ok(()); } @@ -198,7 +192,7 @@ impl SplitSwapStrategyEncoder { } // If we get here, either checked_token wasn't reached or not all tokens were visited - if !visited.contains(validation_checked) { + if !visited.contains(checked_token) { Err(EncodingError::InvalidInput( "Checked token is not reachable through swap path".to_string(), )) @@ -216,12 +210,8 @@ impl StrategyEncoder for SplitSwapStrategyEncoder { solution: Solution, router_address: Bytes, ) -> Result<(Vec, Bytes), EncodingError> { - self.validate_swaps(&solution.swaps)?; - self.validate_token_path_connectivity( - &solution.swaps, - &solution.given_token, - &solution.checked_token, - )?; + self.validate_split_percentages(&solution.swaps)?; + self.validate_swap_path(&solution.swaps, &solution.given_token, &solution.checked_token)?; let (permit, signature) = self.permit2.get_permit( &router_address, &solution.sender, @@ -790,7 +780,7 @@ mod tests { } #[test] - fn test_validate_token_path_connectivity_single_swap() { + fn test_validate_path_single_swap() { let encoder = get_mock_split_swap_strategy_encoder(); let weth = Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap(); let dai = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap(); @@ -804,12 +794,12 @@ mod tests { token_out: dai.clone(), split: 0f64, }]; - let result = encoder.validate_token_path_connectivity(&swaps, &weth, &dai); + let result = encoder.validate_swap_path(&swaps, &weth, &dai); assert_eq!(result, Ok(())); } #[test] - fn test_validate_token_path_connectivity_multiple_swaps() { + fn test_validate_path_multiple_swaps() { let encoder = get_mock_split_swap_strategy_encoder(); let weth = Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap(); let dai = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap(); @@ -836,12 +826,12 @@ mod tests { split: 0f64, }, ]; - let result = encoder.validate_token_path_connectivity(&swaps, &weth, &usdc); + let result = encoder.validate_swap_path(&swaps, &weth, &usdc); assert_eq!(result, Ok(())); } #[test] - fn test_validate_token_path_connectivity_disconnected_path() { + fn test_validate_path_disconnected() { let encoder = get_mock_split_swap_strategy_encoder(); let weth = Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap(); let dai = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap(); @@ -871,7 +861,7 @@ mod tests { split: 0.0, }, ]; - let result = encoder.validate_token_path_connectivity(&disconnected_swaps, &weth, &usdc); + let result = encoder.validate_swap_path(&disconnected_swaps, &weth, &usdc); assert!(matches!( result, Err(EncodingError::InvalidInput(msg)) if msg.contains("not reachable through swap path") @@ -879,7 +869,7 @@ mod tests { } #[test] - fn test_validate_token_path_connectivity_unreachable_checked_token() { + fn test_validate_path_unreachable_checked_token() { let encoder = get_mock_split_swap_strategy_encoder(); let weth = Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap(); let dai = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap(); @@ -895,7 +885,7 @@ mod tests { token_out: dai.clone(), split: 1.0, }]; - let result = encoder.validate_token_path_connectivity(&unreachable_swaps, &weth, &usdc); + let result = encoder.validate_swap_path(&unreachable_swaps, &weth, &usdc); assert!(matches!( result, Err(EncodingError::InvalidInput(msg)) if msg.contains("not reachable through swap path") @@ -903,13 +893,13 @@ mod tests { } #[test] - fn test_validate_token_path_connectivity_empty_swaps() { + fn test_validate_path_empty_swaps() { let encoder = get_mock_split_swap_strategy_encoder(); let weth = Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap(); let usdc = Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap(); let empty_swaps: Vec = vec![]; - let result = encoder.validate_token_path_connectivity(&empty_swaps, &weth, &usdc); + let result = encoder.validate_swap_path(&empty_swaps, &weth, &usdc); assert!(matches!( result, Err(EncodingError::InvalidInput(msg)) if msg.contains("not reachable through swap path") @@ -931,7 +921,7 @@ mod tests { token_out: dai.clone(), split: 0f64, }]; - let result = encoder.validate_swaps(&swaps); + let result = encoder.validate_split_percentages(&swaps); assert_eq!(result, Ok(())); } @@ -975,7 +965,7 @@ mod tests { ]; let encoder = get_mock_split_swap_strategy_encoder(); assert!(encoder - .validate_swaps(&valid_swaps) + .validate_split_percentages(&valid_swaps) .is_ok()); } @@ -1008,7 +998,7 @@ mod tests { ]; let encoder = get_mock_split_swap_strategy_encoder(); assert!(matches!( - encoder.validate_swaps(&invalid_total_swaps), + encoder.validate_split_percentages(&invalid_total_swaps), Err(EncodingError::InvalidInput(msg)) if msg.contains("must have exactly one 0% split") )); } @@ -1042,7 +1032,7 @@ mod tests { ]; let encoder = get_mock_split_swap_strategy_encoder(); assert!(matches!( - encoder.validate_swaps(&invalid_zero_position_swaps), + encoder.validate_split_percentages(&invalid_zero_position_swaps), Err(EncodingError::InvalidInput(msg)) if msg.contains("must be the last swap") )); } @@ -1086,13 +1076,13 @@ mod tests { ]; let encoder = get_mock_split_swap_strategy_encoder(); assert!(matches!( - encoder.validate_swaps(&invalid_overflow_swaps), + encoder.validate_split_percentages(&invalid_overflow_swaps), Err(EncodingError::InvalidInput(msg)) if msg.contains("must be <100%") )); } #[test] - fn test_validate_token_path_connectivity_wrap_eth_i() { + fn test_validate_path_wrap_eth_given_token() { let encoder = get_mock_split_swap_strategy_encoder(); let eth = Bytes::from_str("0x0000000000000000000000000000000000000000").unwrap(); @@ -1110,7 +1100,7 @@ mod tests { split: 0f64, }]; - let result = encoder.validate_token_path_connectivity(&swaps, ð, &usdc); + let result = encoder.validate_swap_path(&swaps, ð, &usdc); assert_eq!(result, Ok(())); } @@ -1133,7 +1123,7 @@ mod tests { split: 0f64, }]; - let result = encoder.validate_token_path_connectivity(&swaps, &usdc, ð); + let result = encoder.validate_swap_path(&swaps, &usdc, ð); assert_eq!(result, Ok(())); } }