diff --git a/src/encoding/evm/strategy_encoder/strategy_encoders.rs b/src/encoding/evm/strategy_encoder/strategy_encoders.rs index 65ea61c..ec19d1b 100644 --- a/src/encoding/evm/strategy_encoder/strategy_encoders.rs +++ b/src/encoding/evm/strategy_encoder/strategy_encoders.rs @@ -151,6 +151,19 @@ impl SplitSwapStrategyEncoder { given_token: &Bytes, checked_token: &Bytes, ) -> Result<(), EncodingError> { + // Special case: If given_token is ETH or checked_token is ETH, treat it as WETH for path + // validation + let eth_address = Bytes::from_str("0x0000000000000000000000000000000000000000") + .map_err(|_| EncodingError::FatalError("Invalid ETH address".to_string()))?; + let weth_address = Bytes::from_str("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2") + .map_err(|_| EncodingError::FatalError("Invalid WETH address".to_string()))?; + + let validation_given = + if given_token == ð_address { &weth_address } else { given_token }; + + let validation_checked = + if checked_token == ð_address { &weth_address } else { checked_token }; + // Build directed graph of token flows let mut graph: HashMap<&Bytes, HashSet<&Bytes>> = HashMap::new(); for swap in swaps { @@ -160,10 +173,10 @@ impl SplitSwapStrategyEncoder { .insert(&swap.token_out); } - // BFS from given_token + // BFS from validation_given let mut visited = HashSet::new(); let mut queue = VecDeque::new(); - queue.push_back(given_token); + queue.push_back(validation_given); while let Some(token) = queue.pop_front() { if !visited.insert(token) { @@ -171,7 +184,7 @@ impl SplitSwapStrategyEncoder { } // Early success check - if token == checked_token && visited.len() == graph.len() + 1 { + if token == validation_checked && visited.len() == graph.len() + 1 { return Ok(()); } @@ -185,7 +198,7 @@ impl SplitSwapStrategyEncoder { } // If we get here, either checked_token wasn't reached or not all tokens were visited - if !visited.contains(checked_token) { + if !visited.contains(validation_checked) { Err(EncodingError::InvalidInput( "Checked token is not reachable through swap path".to_string(), )) @@ -1077,4 +1090,50 @@ mod tests { Err(EncodingError::InvalidInput(msg)) if msg.contains("must be <100%") )); } + + #[test] + fn test_validate_token_path_connectivity_wrap_eth_i() { + let encoder = get_mock_split_swap_strategy_encoder(); + + let eth = Bytes::from_str("0x0000000000000000000000000000000000000000").unwrap(); + let usdc = Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap(); + let weth = Bytes::from_str("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2").unwrap(); + + let swaps = vec![Swap { + component: ProtocolComponent { + id: "pool1".to_string(), + protocol_system: "uniswap_v2".to_string(), + ..Default::default() + }, + token_in: weth.clone(), + token_out: usdc.clone(), + split: 0f64, + }]; + + let result = encoder.validate_token_path_connectivity(&swaps, ð, &usdc); + assert_eq!(result, Ok(())); + } + + #[test] + fn test_validate_token_path_connectivity_wrap_eth_checked_token() { + let encoder = get_mock_split_swap_strategy_encoder(); + + let eth = Bytes::from_str("0x0000000000000000000000000000000000000000").unwrap(); + let usdc = Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap(); + let weth = Bytes::from_str("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2").unwrap(); + + let swaps = vec![Swap { + component: ProtocolComponent { + id: "pool1".to_string(), + protocol_system: "uniswap_v2".to_string(), + ..Default::default() + }, + token_in: usdc.clone(), + token_out: weth.clone(), + split: 0f64, + }]; + + let result = encoder.validate_token_path_connectivity(&swaps, &usdc, ð); + assert_eq!(result, Ok(())); + } }