diff --git a/src/encoding/evm/strategy_encoder/strategy_encoders.rs b/src/encoding/evm/strategy_encoder/strategy_encoders.rs index 0632bad..2936280 100644 --- a/src/encoding/evm/strategy_encoder/strategy_encoders.rs +++ b/src/encoding/evm/strategy_encoder/strategy_encoders.rs @@ -261,31 +261,47 @@ impl StrategyEncoder for ExecutorStrategyEncoder { &self, solution: Solution, ) -> Result<(Vec, Bytes, Option), EncodingError> { - let swap = solution - .swaps + let grouped_swaps = group_swaps(solution.clone().swaps); + let number_of_groups = grouped_swaps.len(); + if number_of_groups > 1 { + return Err(EncodingError::InvalidInput(format!( + "Executor strategy only supports one swap for non-groupable protocols. Found {}", + number_of_groups + ))) + } + + let grouped_swap = grouped_swaps .first() - .ok_or_else(|| EncodingError::InvalidInput("No swaps found in solution".to_string()))?; + .ok_or_else(|| EncodingError::FatalError("Swap grouping failed".to_string()))?; + + let receiver = solution.receiver; + let router_address = solution.router_address; let swap_encoder = self - .get_swap_encoder(&swap.component.protocol_system) + .get_swap_encoder(&grouped_swap.protocol_system) .ok_or_else(|| { EncodingError::InvalidInput(format!( "Swap encoder not found for protocol: {}", - swap.component.protocol_system + grouped_swap.protocol_system )) })?; - let encoding_context = EncodingContext { - receiver: solution.receiver, - exact_out: solution.exact_out, - router_address: solution.router_address, - }; - let protocol_data = swap_encoder.encode_swap(swap.clone(), encoding_context)?; + let mut grouped_protocol_data: Vec> = vec![]; + for swap in grouped_swap.swaps.iter() { + let encoding_context = EncodingContext { + receiver: receiver.clone(), + exact_out: solution.exact_out, + router_address: router_address.clone(), + }; + let protocol_data = swap_encoder.encode_swap(swap.clone(), encoding_context.clone())?; + grouped_protocol_data.push(protocol_data); + } let executor_address = Bytes::from_str(swap_encoder.executor_address()) .map_err(|_| EncodingError::FatalError("Invalid executor address".to_string()))?; + Ok(( - protocol_data, + grouped_protocol_data.abi_encode_packed(), executor_address, Some( swap_encoder @@ -397,6 +413,124 @@ mod tests { ); assert_eq!(selector, Some("swap(uint256,bytes)".to_string())); } + + #[test] + fn test_executor_strategy_encode_too_many_swaps() { + let swap_encoder_registry = get_swap_encoder_registry(); + let encoder = ExecutorStrategyEncoder::new(swap_encoder_registry); + + let token_in = weth(); + let token_out = Bytes::from("0x6b175474e89094c44da98b954eedeac495271d0f"); + + let swap = Swap { + component: ProtocolComponent { + protocol_system: "uniswap_v2".to_string(), + ..Default::default() + }, + token_in: token_in.clone(), + token_out: token_out.clone(), + split: 0f64, + }; + + let solution = Solution { + exact_out: false, + given_token: token_in, + given_amount: BigUint::from(1000000000000000000u64), + expected_amount: Some(BigUint::from(1000000000000000000u64)), + checked_token: token_out, + checked_amount: None, + sender: Bytes::from_str("0x0000000000000000000000000000000000000000").unwrap(), + receiver: Bytes::from_str("0x1d96f2f6bef1202e4ce1ff6dad0c2cb002861d3e").unwrap(), + swaps: vec![swap.clone(), swap], + direct_execution: true, + router_address: Bytes::from_str("0x3Ede3eCa2a72B3aeCC820E955B36f38437D01395").unwrap(), + slippage: None, + native_action: None, + }; + + let result = encoder.encode_strategy(solution); + assert!(result.is_err()); + } + + #[test] + fn test_executor_strategy_encode_grouped_swaps() { + let swap_encoder_registry = get_swap_encoder_registry(); + let encoder = ExecutorStrategyEncoder::new(swap_encoder_registry); + + let weth = weth(); + let dai = Bytes::from("0x6b175474e89094c44da98b954eedeac495271d0f"); + let usdc = Bytes::from("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48"); + + let swap_a = Swap { + component: ProtocolComponent { + id: "0xA478c2975Ab1Ea89e8196811F51A7B7Ade33eB11".to_string(), + protocol_system: "uniswap_v4".to_string(), + ..Default::default() + }, + token_in: weth.clone(), + token_out: dai.clone(), + split: 0f64, + }; + let swap_b = Swap { + component: ProtocolComponent { + id: "0xAE461cA67B15dc8dc81CE7615e0320dA1A9aB8D5".to_string(), + protocol_system: "uniswap_v4".to_string(), + ..Default::default() + }, + token_in: dai.clone(), + token_out: usdc.clone(), + split: 0f64, + }; + + let solution = Solution { + exact_out: false, + given_token: weth, + given_amount: BigUint::from(1000000000000000000u64), + expected_amount: Some(BigUint::from(1000000000000000000u64)), + checked_token: usdc, + checked_amount: None, + sender: Bytes::from_str("0x0000000000000000000000000000000000000000").unwrap(), + // The receiver was generated with `makeAddr("bob") using forge` + receiver: Bytes::from_str("0x1d96f2f6bef1202e4ce1ff6dad0c2cb002861d3e").unwrap(), + swaps: vec![swap_a, swap_b], + direct_execution: true, + router_address: Bytes::from_str("0x3Ede3eCa2a72B3aeCC820E955B36f38437D01395").unwrap(), + slippage: None, + native_action: None, + }; + + let (protocol_data, executor_address, selector) = encoder + .encode_strategy(solution) + .unwrap(); + let hex_protocol_data = encode(&protocol_data); + assert_eq!( + executor_address, + Bytes::from_str("0x5c2f5a71f67c01775180adc06909288b4c329308").unwrap() + ); + assert_eq!( + hex_protocol_data, + String::from(concat!( + // in token + "c02aaa39b223fe8d0a0e5c4f27ead9083c756cc2", + // component id + "a478c2975ab1ea89e8196811f51a7b7ade33eb11", + // receiver + "1d96f2f6bef1202e4ce1ff6dad0c2cb002861d3e", + // zero for one + "00", + // in token + "6b175474e89094c44da98b954eedeac495271d0f", + // component id + "ae461ca67b15dc8dc81ce7615e0320da1a9ab8d5", + // receiver + "1d96f2f6bef1202e4ce1ff6dad0c2cb002861d3e", + // zero for one + "01", + )) + ); + assert_eq!(selector, Some("swap(uint256,bytes)".to_string())); + } + #[rstest] #[case::no_check_no_slippage( None,