diff --git a/src/encoding/evm/strategy_encoder/strategy_encoders.rs b/src/encoding/evm/strategy_encoder/strategy_encoders.rs index 620ecfa..508f0c0 100644 --- a/src/encoding/evm/strategy_encoder/strategy_encoders.rs +++ b/src/encoding/evm/strategy_encoder/strategy_encoders.rs @@ -95,8 +95,8 @@ impl SplitSwapStrategyEncoder { continue; } - // Check if exactly one swap has 0% split and it's the last one let mut found_zero_split = false; + let mut total_percentage = 0.0; for (i, swap) in token_swaps.iter().enumerate() { match (swap.split == 0.0, i == token_swaps.len() - 1) { (true, false) => { @@ -106,7 +106,15 @@ impl SplitSwapStrategyEncoder { ))) } (true, true) => found_zero_split = true, - (false, _) => (), + (false, _) => { + if swap.split <= 0.0 { + return Err(EncodingError::InvalidInput(format!( + "Non-remainder splits must be >0% for token {:?}", + token + ))); + } + total_percentage += swap.split; + } } } @@ -117,21 +125,6 @@ impl SplitSwapStrategyEncoder { ))); } - // Sum non-zero splits and validate each is >0% and <100% - let mut total_percentage = 0.0; - for swap in token_swaps - .iter() - .take(token_swaps.len() - 1) - { - if swap.split <= 0.0 { - return Err(EncodingError::InvalidInput(format!( - "Non-remainder splits must be >0% for token {:?}", - token - ))); - } - total_percentage += swap.split; - } - // Total must be <100% to leave room for remainder if total_percentage >= 1.0 { return Err(EncodingError::InvalidInput(format!( @@ -150,13 +143,26 @@ impl SplitSwapStrategyEncoder { swaps: &[Swap], given_token: &Bytes, checked_token: &Bytes, + native_action: &Option, ) -> Result<(), EncodingError> { - // Special case: If given_token is ETH or checked_token is ETH, treat it as WETH for path - // validation - let given_token = if *given_token == *NATIVE_ADDRESS { &WETH_ADDRESS } else { given_token }; + // Convert ETH to WETH only if there's a corresponding wrap/unwrap action + let given_token = if *given_token == *NATIVE_ADDRESS { + match native_action { + Some(NativeAction::Wrap) => &WETH_ADDRESS, + _ => given_token, + } + } else { + given_token + }; - let checked_token = - if *checked_token == *NATIVE_ADDRESS { &WETH_ADDRESS } else { checked_token }; + let checked_token = if *checked_token == *NATIVE_ADDRESS { + match native_action { + Some(NativeAction::Unwrap) => &WETH_ADDRESS, + _ => checked_token, + } + } else { + checked_token + }; // Build directed graph of token flows let mut graph: HashMap<&Bytes, HashSet<&Bytes>> = HashMap::new(); @@ -211,7 +217,12 @@ impl StrategyEncoder for SplitSwapStrategyEncoder { router_address: Bytes, ) -> Result<(Vec, Bytes), EncodingError> { self.validate_split_percentages(&solution.swaps)?; - self.validate_swap_path(&solution.swaps, &solution.given_token, &solution.checked_token)?; + self.validate_swap_path( + &solution.swaps, + &solution.given_token, + &solution.checked_token, + &solution.native_action, + )?; let (permit, signature) = self.permit2.get_permit( &router_address, &solution.sender, @@ -794,7 +805,7 @@ mod tests { token_out: dai.clone(), split: 0f64, }]; - let result = encoder.validate_swap_path(&swaps, &weth, &dai); + let result = encoder.validate_swap_path(&swaps, &weth, &dai, &None); assert_eq!(result, Ok(())); } @@ -826,7 +837,7 @@ mod tests { split: 0f64, }, ]; - let result = encoder.validate_swap_path(&swaps, &weth, &usdc); + let result = encoder.validate_swap_path(&swaps, &weth, &usdc, &None); assert_eq!(result, Ok(())); } @@ -861,7 +872,7 @@ mod tests { split: 0.0, }, ]; - let result = encoder.validate_swap_path(&disconnected_swaps, &weth, &usdc); + let result = encoder.validate_swap_path(&disconnected_swaps, &weth, &usdc, &None); assert!(matches!( result, Err(EncodingError::InvalidInput(msg)) if msg.contains("not reachable through swap path") @@ -885,7 +896,7 @@ mod tests { token_out: dai.clone(), split: 1.0, }]; - let result = encoder.validate_swap_path(&unreachable_swaps, &weth, &usdc); + let result = encoder.validate_swap_path(&unreachable_swaps, &weth, &usdc, &None); assert!(matches!( result, Err(EncodingError::InvalidInput(msg)) if msg.contains("not reachable through swap path") @@ -899,7 +910,7 @@ mod tests { let usdc = Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap(); let empty_swaps: Vec = vec![]; - let result = encoder.validate_swap_path(&empty_swaps, &weth, &usdc); + let result = encoder.validate_swap_path(&empty_swaps, &weth, &usdc, &None); assert!(matches!( result, Err(EncodingError::InvalidInput(msg)) if msg.contains("not reachable through swap path") @@ -1100,7 +1111,7 @@ mod tests { split: 0f64, }]; - let result = encoder.validate_swap_path(&swaps, ð, &usdc); + let result = encoder.validate_swap_path(&swaps, ð, &usdc, &Some(NativeAction::Wrap)); assert_eq!(result, Ok(())); } @@ -1123,7 +1134,7 @@ mod tests { split: 0f64, }]; - let result = encoder.validate_swap_path(&swaps, &usdc, ð); + let result = encoder.validate_swap_path(&swaps, &usdc, ð, &Some(NativeAction::Unwrap)); assert_eq!(result, Ok(())); } }