refactor: deduplicate validation for Split & USV4 strategy

We use the same validation for both SplitSwapStrategy and UniswapV4Strategy - so, instead, put the validation into a Validator trait that gets initialized in both swap strategies in order to avoid duplication.
This commit is contained in:
TAMARA LIPOWSKI
2025-02-13 15:30:09 -05:00
parent f32210bb1f
commit b452372714

View File

@@ -81,24 +81,14 @@ pub struct SplitSwapStrategyEncoder {
selector: String,
native_address: Bytes,
wrapped_address: Bytes,
split_swap_validator: SplitSwapValidator,
}
impl SplitSwapStrategyEncoder {
pub fn new(
signer_pk: String,
chain: Chain,
swap_encoder_registry: SwapEncoderRegistry,
) -> Result<Self, EncodingError> {
let selector = "swap(uint256,address,address,uint256,bool,bool,uint256,address,((address,uint160,uint48,uint48),address,uint256),bytes,bytes)".to_string();
Ok(Self {
permit2: Permit2::new(signer_pk, chain.clone())?,
selector,
swap_encoder_registry,
native_address: chain.native_token()?,
wrapped_address: chain.wrapped_token()?,
})
}
/// Validates whether a sequence of split swaps represents a valid solution.
#[derive(Clone)]
pub struct SplitSwapValidator;
impl SplitSwapValidator {
/// Raises an error if the split percentages are invalid.
///
/// Split percentages are considered valid if all the following conditions are met:
@@ -193,20 +183,22 @@ impl SplitSwapStrategyEncoder {
given_token: &Bytes,
checked_token: &Bytes,
native_action: &Option<NativeAction>,
native_address: &Bytes,
wrapped_address: &Bytes,
) -> Result<(), EncodingError> {
// Convert ETH to WETH only if there's a corresponding wrap/unwrap action
let given_token = if *given_token == *self.native_address {
let given_token = if *given_token == *native_address {
match native_action {
Some(NativeAction::Wrap) => &self.wrapped_address,
Some(NativeAction::Wrap) => wrapped_address,
_ => given_token,
}
} else {
given_token
};
let checked_token = if *checked_token == *self.native_address {
let checked_token = if *checked_token == *native_address {
match native_action {
Some(NativeAction::Unwrap) => &self.wrapped_address,
Some(NativeAction::Unwrap) => wrapped_address,
_ => checked_token,
}
} else {
@@ -259,6 +251,24 @@ impl SplitSwapStrategyEncoder {
}
}
impl SplitSwapStrategyEncoder {
pub fn new(
signer_pk: String,
chain: Chain,
swap_encoder_registry: SwapEncoderRegistry,
) -> Result<Self, EncodingError> {
let selector = "swap(uint256,address,address,uint256,bool,bool,uint256,address,((address,uint160,uint48,uint48),address,uint256),bytes,bytes)".to_string();
Ok(Self {
permit2: Permit2::new(signer_pk, chain.clone())?,
selector,
swap_encoder_registry,
native_address: chain.native_token()?,
wrapped_address: chain.wrapped_token()?,
split_swap_validator: SplitSwapValidator,
})
}
}
/// To be used if there are two or more UniswapV4 swaps consecutively. They can be combined as a
/// gas optimization.
#[derive(Clone)]
@@ -268,6 +278,7 @@ pub struct UniswapV4StrategyEncoder {
selector: String,
native_address: Bytes,
wrapped_address: Bytes,
split_swap_validator: SplitSwapValidator,
}
impl EVMStrategyEncoder for UniswapV4StrategyEncoder {}
@@ -277,13 +288,17 @@ impl StrategyEncoder for UniswapV4StrategyEncoder {
&self,
solution: Solution,
) -> Result<(Vec<u8>, Bytes, Option<String>), EncodingError> {
self.validate_split_percentages(&solution.swaps)?;
self.validate_swap_path(
&solution.swaps,
&solution.given_token,
&solution.checked_token,
&solution.native_action,
)?;
self.split_swap_validator
.validate_split_percentages(&solution.swaps)?;
self.split_swap_validator
.validate_swap_path(
&solution.swaps,
&solution.given_token,
&solution.checked_token,
&solution.native_action,
&self.native_address,
&self.wrapped_address,
)?;
let (permit, signature) = self.permit2.get_permit(
&solution.router_address,
&solution.sender,
@@ -447,167 +462,9 @@ impl UniswapV4StrategyEncoder {
swap_encoder_registry,
native_address: chain.native_token()?,
wrapped_address: chain.wrapped_token()?,
split_swap_validator: SplitSwapValidator,
})
}
/// Raises an error if the split percentages are invalid.
///
/// Split percentages are considered valid if all the following conditions are met:
/// * Each split amount is < 1 (100%)
/// * There is exactly one 0% split for each token, and it's the last swap specified, signifying
/// to the router to send the remainder of the token to the designated protocol
/// * The sum of all non-remainder splits for each token is < 1 (100%)
/// * There are no negative split amounts
fn validate_split_percentages(&self, swaps: &[Swap]) -> Result<(), EncodingError> {
let mut swaps_by_token: HashMap<Bytes, Vec<&Swap>> = HashMap::new();
for swap in swaps {
if swap.split >= 1.0 {
return Err(EncodingError::InvalidInput(format!(
"Split percentage must be less than 1 (100%), got {}",
swap.split
)));
}
swaps_by_token
.entry(swap.token_in.clone())
.or_default()
.push(swap);
}
for (token, token_swaps) in swaps_by_token {
// Single swaps don't need remainder handling
if token_swaps.len() == 1 {
if token_swaps[0].split != 0.0 {
return Err(EncodingError::InvalidInput(format!(
"Single swap must have 0% split for token {:?}",
token
)));
}
continue;
}
let mut found_zero_split = false;
let mut total_percentage = 0.0;
for (i, swap) in token_swaps.iter().enumerate() {
match (swap.split == 0.0, i == token_swaps.len() - 1) {
(true, false) => {
return Err(EncodingError::InvalidInput(format!(
"The 0% split for token {:?} must be the last swap",
token
)))
}
(true, true) => found_zero_split = true,
(false, _) => {
if swap.split < 0.0 {
return Err(EncodingError::InvalidInput(format!(
"All splits must be >= 0% for token {:?}",
token
)));
}
total_percentage += swap.split;
}
}
}
if !found_zero_split {
return Err(EncodingError::InvalidInput(format!(
"Token {:?} must have exactly one 0% split for remainder handling",
token
)));
}
// Total must be <100% to leave room for remainder
if total_percentage >= 1.0 {
return Err(EncodingError::InvalidInput(format!(
"Total of non-remainder splits for token {:?} must be <100%, got {}%",
token,
total_percentage * 100.0
)));
}
}
Ok(())
}
/// Raises an error if swaps do not represent a valid path from the given token to the checked
/// token.
///
/// A path is considered valid if all the following conditions are met:
/// * The checked token is reachable from the given token through the swap path
/// * There are no tokens which are unconnected from the main path
///
/// If the given token is the native token and the native action is WRAP, it will be converted
/// to the wrapped token before validating the swap path. The same principle applies for the
/// checked token and the UNWRAP action.
fn validate_swap_path(
&self,
swaps: &[Swap],
given_token: &Bytes,
checked_token: &Bytes,
native_action: &Option<NativeAction>,
) -> Result<(), EncodingError> {
// Convert ETH to WETH only if there's a corresponding wrap/unwrap action
let given_token = if *given_token == *self.native_address {
match native_action {
Some(NativeAction::Wrap) => &self.wrapped_address,
_ => given_token,
}
} else {
given_token
};
let checked_token = if *checked_token == *self.native_address {
match native_action {
Some(NativeAction::Unwrap) => &self.wrapped_address,
_ => checked_token,
}
} else {
checked_token
};
// Build directed graph of token flows
let mut graph: HashMap<&Bytes, HashSet<&Bytes>> = HashMap::new();
for swap in swaps {
graph
.entry(&swap.token_in)
.or_default()
.insert(&swap.token_out);
}
// BFS from validation_given
let mut visited = HashSet::new();
let mut queue = VecDeque::new();
queue.push_back(given_token);
while let Some(token) = queue.pop_front() {
if !visited.insert(token) {
continue;
}
// Early success check
if token == checked_token && visited.len() == graph.len() + 1 {
return Ok(());
}
if let Some(next_tokens) = graph.get(token) {
for &next_token in next_tokens {
if !visited.contains(next_token) {
queue.push_back(next_token);
}
}
}
}
// If we get here, either checked_token wasn't reached or not all tokens were visited
if !visited.contains(checked_token) {
Err(EncodingError::InvalidInput(
"Checked token is not reachable through swap path".to_string(),
))
} else {
Err(EncodingError::InvalidInput(
"Some tokens are not connected to the main path".to_string(),
))
}
}
}
impl EVMStrategyEncoder for SplitSwapStrategyEncoder {}
@@ -617,13 +474,17 @@ impl StrategyEncoder for SplitSwapStrategyEncoder {
&self,
solution: Solution,
) -> Result<(Vec<u8>, Bytes, Option<String>), EncodingError> {
self.validate_split_percentages(&solution.swaps)?;
self.validate_swap_path(
&solution.swaps,
&solution.given_token,
&solution.checked_token,
&solution.native_action,
)?;
self.split_swap_validator
.validate_split_percentages(&solution.swaps)?;
self.split_swap_validator
.validate_swap_path(
&solution.swaps,
&solution.given_token,
&solution.checked_token,
&solution.native_action,
&self.native_address,
&self.wrapped_address,
)?;
let (permit, signature) = self.permit2.get_permit(
&solution.router_address,
&solution.sender,
@@ -1204,16 +1065,10 @@ mod tests {
println!("{}", _hex_calldata);
}
fn get_mock_split_swap_strategy_encoder() -> SplitSwapStrategyEncoder {
let private_key =
"0x123456789abcdef123456789abcdef123456789abcdef123456789abcdef1234".to_string();
let swap_encoder_registry = get_swap_encoder_registry();
SplitSwapStrategyEncoder::new(private_key, eth_chain(), swap_encoder_registry).unwrap()
}
#[test]
fn test_validate_path_single_swap() {
let encoder = get_mock_split_swap_strategy_encoder();
let validator = SplitSwapValidator;
let eth = Bytes::from_str("0x0000000000000000000000000000000000000000").unwrap();
let weth = Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap();
let dai = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap();
let swaps = vec![Swap {
@@ -1226,13 +1081,14 @@ mod tests {
token_out: dai.clone(),
split: 0f64,
}];
let result = encoder.validate_swap_path(&swaps, &weth, &dai, &None);
let result = validator.validate_swap_path(&swaps, &weth, &dai, &None, &eth, &weth);
assert_eq!(result, Ok(()));
}
#[test]
fn test_validate_path_multiple_swaps() {
let encoder = get_mock_split_swap_strategy_encoder();
let validator = SplitSwapValidator;
let eth = Bytes::from_str("0x0000000000000000000000000000000000000000").unwrap();
let weth = Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap();
let dai = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap();
let usdc = Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap();
@@ -1258,13 +1114,14 @@ mod tests {
split: 0f64,
},
];
let result = encoder.validate_swap_path(&swaps, &weth, &usdc, &None);
let result = validator.validate_swap_path(&swaps, &weth, &usdc, &None, &eth, &weth);
assert_eq!(result, Ok(()));
}
#[test]
fn test_validate_path_disconnected() {
let encoder = get_mock_split_swap_strategy_encoder();
let validator = SplitSwapValidator;
let eth = Bytes::from_str("0x0000000000000000000000000000000000000000").unwrap();
let weth = Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap();
let dai = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap();
let usdc = Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap();
@@ -1293,7 +1150,8 @@ mod tests {
split: 0.0,
},
];
let result = encoder.validate_swap_path(&disconnected_swaps, &weth, &usdc, &None);
let result =
validator.validate_swap_path(&disconnected_swaps, &weth, &usdc, &None, &eth, &weth);
assert!(matches!(
result,
Err(EncodingError::InvalidInput(msg)) if msg.contains("not reachable through swap path")
@@ -1302,7 +1160,8 @@ mod tests {
#[test]
fn test_validate_path_unreachable_checked_token() {
let encoder = get_mock_split_swap_strategy_encoder();
let validator = SplitSwapValidator;
let eth = Bytes::from_str("0x0000000000000000000000000000000000000000").unwrap();
let weth = Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap();
let dai = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap();
let usdc = Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap();
@@ -1317,7 +1176,8 @@ mod tests {
token_out: dai.clone(),
split: 1.0,
}];
let result = encoder.validate_swap_path(&unreachable_swaps, &weth, &usdc, &None);
let result =
validator.validate_swap_path(&unreachable_swaps, &weth, &usdc, &None, &eth, &weth);
assert!(matches!(
result,
Err(EncodingError::InvalidInput(msg)) if msg.contains("not reachable through swap path")
@@ -1326,12 +1186,13 @@ mod tests {
#[test]
fn test_validate_path_empty_swaps() {
let encoder = get_mock_split_swap_strategy_encoder();
let validator = SplitSwapValidator;
let eth = Bytes::from_str("0x0000000000000000000000000000000000000000").unwrap();
let weth = Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap();
let usdc = Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap();
let empty_swaps: Vec<Swap> = vec![];
let result = encoder.validate_swap_path(&empty_swaps, &weth, &usdc, &None);
let result = validator.validate_swap_path(&empty_swaps, &weth, &usdc, &None, &eth, &weth);
assert!(matches!(
result,
Err(EncodingError::InvalidInput(msg)) if msg.contains("not reachable through swap path")
@@ -1340,7 +1201,7 @@ mod tests {
#[test]
fn test_validate_swap_single() {
let encoder = get_mock_split_swap_strategy_encoder();
let validator = SplitSwapValidator;
let weth = Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap();
let dai = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap();
let swaps = vec![Swap {
@@ -1353,12 +1214,13 @@ mod tests {
token_out: dai.clone(),
split: 0f64,
}];
let result = encoder.validate_split_percentages(&swaps);
let result = validator.validate_split_percentages(&swaps);
assert_eq!(result, Ok(()));
}
#[test]
fn test_validate_swaps_multiple() {
let validator = SplitSwapValidator;
let weth = Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap();
let dai = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap();
@@ -1395,14 +1257,14 @@ mod tests {
split: 0.0, // Remainder (20%)
},
];
let encoder = get_mock_split_swap_strategy_encoder();
assert!(encoder
assert!(validator
.validate_split_percentages(&valid_swaps)
.is_ok());
}
#[test]
fn test_validate_swaps_no_remainder_split() {
let validator = SplitSwapValidator;
let weth = Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap();
let dai = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap();
@@ -1428,15 +1290,15 @@ mod tests {
split: 0.3,
},
];
let encoder = get_mock_split_swap_strategy_encoder();
assert!(matches!(
encoder.validate_split_percentages(&invalid_total_swaps),
validator.validate_split_percentages(&invalid_total_swaps),
Err(EncodingError::InvalidInput(msg)) if msg.contains("must have exactly one 0% split")
));
}
#[test]
fn test_validate_swaps_zero_split_not_at_end() {
let validator = SplitSwapValidator;
let weth = Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap();
let dai = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap();
@@ -1462,15 +1324,15 @@ mod tests {
split: 0.5,
},
];
let encoder = get_mock_split_swap_strategy_encoder();
assert!(matches!(
encoder.validate_split_percentages(&invalid_zero_position_swaps),
validator.validate_split_percentages(&invalid_zero_position_swaps),
Err(EncodingError::InvalidInput(msg)) if msg.contains("must be the last swap")
));
}
#[test]
fn test_validate_swaps_splits_exceed_hundred_percent() {
let validator = SplitSwapValidator;
let weth = Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap();
let dai = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap();
@@ -1506,17 +1368,15 @@ mod tests {
split: 0.0,
},
];
let encoder = get_mock_split_swap_strategy_encoder();
assert!(matches!(
encoder.validate_split_percentages(&invalid_overflow_swaps),
validator.validate_split_percentages(&invalid_overflow_swaps),
Err(EncodingError::InvalidInput(msg)) if msg.contains("must be <100%")
));
}
#[test]
fn test_validate_path_wrap_eth_given_token() {
let encoder = get_mock_split_swap_strategy_encoder();
let validator = SplitSwapValidator;
let eth = Bytes::from_str("0x0000000000000000000000000000000000000000").unwrap();
let usdc = Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap();
let weth = Bytes::from_str("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2").unwrap();
@@ -1532,14 +1392,20 @@ mod tests {
split: 0f64,
}];
let result = encoder.validate_swap_path(&swaps, &eth, &usdc, &Some(NativeAction::Wrap));
let result = validator.validate_swap_path(
&swaps,
&eth,
&usdc,
&Some(NativeAction::Wrap),
&eth,
&weth,
);
assert_eq!(result, Ok(()));
}
#[test]
fn test_validate_token_path_connectivity_wrap_eth_checked_token() {
let encoder = get_mock_split_swap_strategy_encoder();
let validator = SplitSwapValidator;
let eth = Bytes::from_str("0x0000000000000000000000000000000000000000").unwrap();
let usdc = Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap();
let weth = Bytes::from_str("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2").unwrap();
@@ -1555,7 +1421,14 @@ mod tests {
split: 0f64,
}];
let result = encoder.validate_swap_path(&swaps, &usdc, &eth, &Some(NativeAction::Unwrap));
let result = validator.validate_swap_path(
&swaps,
&usdc,
&eth,
&Some(NativeAction::Unwrap),
&eth,
&weth,
);
assert_eq!(result, Ok(()));
}
}