feat: Support swap grouping for executor strategy

- Also make sure the strategy errors if > 1 group.
This commit is contained in:
TAMARA LIPOWSKI
2025-02-18 12:50:38 -05:00
parent 962e460e34
commit ac831176d4

View File

@@ -261,31 +261,47 @@ impl StrategyEncoder for ExecutorStrategyEncoder {
&self,
solution: Solution,
) -> Result<(Vec<u8>, Bytes, Option<String>), EncodingError> {
let swap = solution
.swaps
let grouped_swaps = group_swaps(solution.clone().swaps);
let number_of_groups = grouped_swaps.len();
if number_of_groups > 1 {
return Err(EncodingError::InvalidInput(format!(
"Executor strategy only supports one swap for non-groupable protocols. Found {}",
number_of_groups
)))
}
let grouped_swap = grouped_swaps
.first()
.ok_or_else(|| EncodingError::InvalidInput("No swaps found in solution".to_string()))?;
.ok_or_else(|| EncodingError::FatalError("Swap grouping failed".to_string()))?;
let receiver = solution.receiver;
let router_address = solution.router_address;
let swap_encoder = self
.get_swap_encoder(&swap.component.protocol_system)
.get_swap_encoder(&grouped_swap.protocol_system)
.ok_or_else(|| {
EncodingError::InvalidInput(format!(
"Swap encoder not found for protocol: {}",
swap.component.protocol_system
grouped_swap.protocol_system
))
})?;
let encoding_context = EncodingContext {
receiver: solution.receiver,
exact_out: solution.exact_out,
router_address: solution.router_address,
};
let protocol_data = swap_encoder.encode_swap(swap.clone(), encoding_context)?;
let mut grouped_protocol_data: Vec<Vec<u8>> = vec![];
for swap in grouped_swap.swaps.iter() {
let encoding_context = EncodingContext {
receiver: receiver.clone(),
exact_out: solution.exact_out,
router_address: router_address.clone(),
};
let protocol_data = swap_encoder.encode_swap(swap.clone(), encoding_context.clone())?;
grouped_protocol_data.push(protocol_data);
}
let executor_address = Bytes::from_str(swap_encoder.executor_address())
.map_err(|_| EncodingError::FatalError("Invalid executor address".to_string()))?;
Ok((
protocol_data,
grouped_protocol_data.abi_encode_packed(),
executor_address,
Some(
swap_encoder
@@ -397,6 +413,124 @@ mod tests {
);
assert_eq!(selector, Some("swap(uint256,bytes)".to_string()));
}
#[test]
fn test_executor_strategy_encode_too_many_swaps() {
let swap_encoder_registry = get_swap_encoder_registry();
let encoder = ExecutorStrategyEncoder::new(swap_encoder_registry);
let token_in = weth();
let token_out = Bytes::from("0x6b175474e89094c44da98b954eedeac495271d0f");
let swap = Swap {
component: ProtocolComponent {
protocol_system: "uniswap_v2".to_string(),
..Default::default()
},
token_in: token_in.clone(),
token_out: token_out.clone(),
split: 0f64,
};
let solution = Solution {
exact_out: false,
given_token: token_in,
given_amount: BigUint::from(1000000000000000000u64),
expected_amount: Some(BigUint::from(1000000000000000000u64)),
checked_token: token_out,
checked_amount: None,
sender: Bytes::from_str("0x0000000000000000000000000000000000000000").unwrap(),
receiver: Bytes::from_str("0x1d96f2f6bef1202e4ce1ff6dad0c2cb002861d3e").unwrap(),
swaps: vec![swap.clone(), swap],
direct_execution: true,
router_address: Bytes::from_str("0x3Ede3eCa2a72B3aeCC820E955B36f38437D01395").unwrap(),
slippage: None,
native_action: None,
};
let result = encoder.encode_strategy(solution);
assert!(result.is_err());
}
#[test]
fn test_executor_strategy_encode_grouped_swaps() {
let swap_encoder_registry = get_swap_encoder_registry();
let encoder = ExecutorStrategyEncoder::new(swap_encoder_registry);
let weth = weth();
let dai = Bytes::from("0x6b175474e89094c44da98b954eedeac495271d0f");
let usdc = Bytes::from("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48");
let swap_a = Swap {
component: ProtocolComponent {
id: "0xA478c2975Ab1Ea89e8196811F51A7B7Ade33eB11".to_string(),
protocol_system: "uniswap_v4".to_string(),
..Default::default()
},
token_in: weth.clone(),
token_out: dai.clone(),
split: 0f64,
};
let swap_b = Swap {
component: ProtocolComponent {
id: "0xAE461cA67B15dc8dc81CE7615e0320dA1A9aB8D5".to_string(),
protocol_system: "uniswap_v4".to_string(),
..Default::default()
},
token_in: dai.clone(),
token_out: usdc.clone(),
split: 0f64,
};
let solution = Solution {
exact_out: false,
given_token: weth,
given_amount: BigUint::from(1000000000000000000u64),
expected_amount: Some(BigUint::from(1000000000000000000u64)),
checked_token: usdc,
checked_amount: None,
sender: Bytes::from_str("0x0000000000000000000000000000000000000000").unwrap(),
// The receiver was generated with `makeAddr("bob") using forge`
receiver: Bytes::from_str("0x1d96f2f6bef1202e4ce1ff6dad0c2cb002861d3e").unwrap(),
swaps: vec![swap_a, swap_b],
direct_execution: true,
router_address: Bytes::from_str("0x3Ede3eCa2a72B3aeCC820E955B36f38437D01395").unwrap(),
slippage: None,
native_action: None,
};
let (protocol_data, executor_address, selector) = encoder
.encode_strategy(solution)
.unwrap();
let hex_protocol_data = encode(&protocol_data);
assert_eq!(
executor_address,
Bytes::from_str("0x5c2f5a71f67c01775180adc06909288b4c329308").unwrap()
);
assert_eq!(
hex_protocol_data,
String::from(concat!(
// in token
"c02aaa39b223fe8d0a0e5c4f27ead9083c756cc2",
// component id
"a478c2975ab1ea89e8196811f51a7b7ade33eb11",
// receiver
"1d96f2f6bef1202e4ce1ff6dad0c2cb002861d3e",
// zero for one
"00",
// in token
"6b175474e89094c44da98b954eedeac495271d0f",
// component id
"ae461ca67b15dc8dc81ce7615e0320da1a9ab8d5",
// receiver
"1d96f2f6bef1202e4ce1ff6dad0c2cb002861d3e",
// zero for one
"01",
))
);
assert_eq!(selector, Some("swap(uint256,bytes)".to_string()));
}
#[rstest]
#[case::no_check_no_slippage(
None,