chore: move validation methods inside impl

This commit is contained in:
royvardhan
2025-02-04 21:25:06 +05:30
parent b69aef9b8f
commit b8013c6e7e

View File

@@ -67,6 +67,145 @@ impl SplitSwapStrategyEncoder {
let selector = "swap(uint256,address,address,uint256,bool,bool,uint256,address,((address,uint160,uint48,uint48),address,uint256),bytes,bytes)".to_string();
Ok(Self { permit2: Permit2::new(signer_pk, chain)?, selector })
}
fn validate_swaps(&self, swaps: &[Swap]) -> Result<(), EncodingError> {
for swap in swaps {
if swap.split >= 1.0 {
return Err(EncodingError::InvalidInput(format!(
"Split percentage must be less than 1 (100%), got {}",
swap.split
)));
}
}
let mut swaps_by_token: HashMap<Bytes, Vec<&Swap>> = HashMap::new();
for swap in swaps {
swaps_by_token
.entry(swap.token_in.clone())
.or_default()
.push(swap);
}
for (token, token_swaps) in swaps_by_token {
if token_swaps.is_empty() {
return Err(EncodingError::InvalidInput(format!(
"No swaps found for token {:?}",
token
)));
}
// Single swaps don't need remainder handling
if token_swaps.len() == 1 {
if token_swaps[0].split != 0.0 {
return Err(EncodingError::InvalidInput(format!(
"Single swap must have 0% split for token {:?}",
token
)));
}
continue;
}
// Check if exactly one swap has 0% split and it's the last one
let mut found_zero_split = false;
for (i, swap) in token_swaps.iter().enumerate() {
match (swap.split == 0.0, i == token_swaps.len() - 1) {
(true, false) => {
return Err(EncodingError::InvalidInput(format!(
"The 0% split for token {:?} must be the last swap",
token
)))
}
(true, true) => found_zero_split = true,
(false, _) => (),
}
}
if !found_zero_split {
return Err(EncodingError::InvalidInput(format!(
"Token {:?} must have exactly one 0% split for remainder handling",
token
)));
}
// Sum non-zero splits and validate each is >0% and <100%
let mut total_percentage = 0.0;
for swap in token_swaps
.iter()
.take(token_swaps.len() - 1)
{
if swap.split <= 0.0 {
return Err(EncodingError::InvalidInput(format!(
"Non-remainder splits must be >0% for token {:?}",
token
)));
}
total_percentage += swap.split;
}
// Total must be <100% to leave room for remainder
if total_percentage >= 1.0 {
return Err(EncodingError::InvalidInput(format!(
"Total of non-remainder splits for token {:?} must be <100%, got {}%",
token,
total_percentage * 100.0
)));
}
}
Ok(())
}
fn validate_token_path_connectivity(
&self,
swaps: &[Swap],
given_token: &Bytes,
checked_token: &Bytes,
) -> Result<(), EncodingError> {
// Build directed graph of token flows
let mut graph: HashMap<&Bytes, HashSet<&Bytes>> = HashMap::new();
for swap in swaps {
graph
.entry(&swap.token_in)
.or_default()
.insert(&swap.token_out);
}
// BFS from given_token
let mut visited = HashSet::new();
let mut queue = VecDeque::new();
queue.push_back(given_token);
while let Some(token) = queue.pop_front() {
if !visited.insert(token) {
continue;
}
// Early success check
if token == checked_token && visited.len() == graph.len() + 1 {
return Ok(());
}
if let Some(next_tokens) = graph.get(token) {
for &next_token in next_tokens {
if !visited.contains(next_token) {
queue.push_back(next_token);
}
}
}
}
// If we get here, either checked_token wasn't reached or not all tokens were visited
if !visited.contains(checked_token) {
Err(EncodingError::InvalidInput(
"Checked token is not reachable through swap path".to_string(),
))
} else {
Err(EncodingError::InvalidInput(
"Some tokens are not connected to the main path".to_string(),
))
}
}
}
impl EVMStrategyEncoder for SplitSwapStrategyEncoder {}
impl StrategyEncoder for SplitSwapStrategyEncoder {
@@ -75,8 +214,8 @@ impl StrategyEncoder for SplitSwapStrategyEncoder {
solution: Solution,
router_address: Bytes,
) -> Result<(Vec<u8>, Bytes), EncodingError> {
validate_swaps(&solution.swaps)?;
validate_token_path_connectivity(
self.validate_swaps(&solution.swaps)?;
self.validate_token_path_connectivity(
&solution.swaps,
&solution.given_token,
&solution.checked_token,
@@ -213,144 +352,6 @@ impl StrategyEncoder for SplitSwapStrategyEncoder {
}
}
fn validate_swaps(swaps: &[Swap]) -> Result<(), EncodingError> {
for swap in swaps {
if swap.split >= 1.0 {
return Err(EncodingError::InvalidInput(format!(
"Split percentage must be less than 1 (100%), got {}",
swap.split
)));
}
}
let mut swaps_by_token: HashMap<Bytes, Vec<&Swap>> = HashMap::new();
for swap in swaps {
swaps_by_token
.entry(swap.token_in.clone())
.or_default()
.push(swap);
}
for (token, token_swaps) in swaps_by_token {
if token_swaps.is_empty() {
return Err(EncodingError::InvalidInput(format!(
"No swaps found for token {:?}",
token
)));
}
// Single swaps don't need remainder handling
if token_swaps.len() == 1 {
if token_swaps[0].split != 0.0 {
return Err(EncodingError::InvalidInput(format!(
"Single swap must have 0% split for token {:?}",
token
)));
}
continue;
}
// Check if exactly one swap has 0% split and it's the last one
let mut found_zero_split = false;
for (i, swap) in token_swaps.iter().enumerate() {
match (swap.split == 0.0, i == token_swaps.len() - 1) {
(true, false) => {
return Err(EncodingError::InvalidInput(format!(
"The 0% split for token {:?} must be the last swap",
token
)))
}
(true, true) => found_zero_split = true,
(false, _) => (),
}
}
if !found_zero_split {
return Err(EncodingError::InvalidInput(format!(
"Token {:?} must have exactly one 0% split for remainder handling",
token
)));
}
// Sum non-zero splits and validate each is >0% and <100%
let mut total_percentage = 0.0;
for swap in token_swaps
.iter()
.take(token_swaps.len() - 1)
{
if swap.split <= 0.0 {
return Err(EncodingError::InvalidInput(format!(
"Non-remainder splits must be >0% for token {:?}",
token
)));
}
total_percentage += swap.split;
}
// Total must be <100% to leave room for remainder
if total_percentage >= 1.0 {
return Err(EncodingError::InvalidInput(format!(
"Total of non-remainder splits for token {:?} must be <100%, got {}%",
token,
total_percentage * 100.0
)));
}
}
Ok(())
}
fn validate_token_path_connectivity(
swaps: &[Swap],
given_token: &Bytes,
checked_token: &Bytes,
) -> Result<(), EncodingError> {
// Build directed graph of token flows
let mut graph: HashMap<&Bytes, HashSet<&Bytes>> = HashMap::new();
for swap in swaps {
graph
.entry(&swap.token_in)
.or_default()
.insert(&swap.token_out);
}
// BFS from given_token
let mut visited = HashSet::new();
let mut queue = VecDeque::new();
queue.push_back(given_token);
while let Some(token) = queue.pop_front() {
if !visited.insert(token) {
continue;
}
// Early success check
if token == checked_token && visited.len() == graph.len() + 1 {
return Ok(());
}
if let Some(next_tokens) = graph.get(token) {
for &next_token in next_tokens {
if !visited.contains(next_token) {
queue.push_back(next_token);
}
}
}
}
// If we get here, either checked_token wasn't reached or not all tokens were visited
if !visited.contains(checked_token) {
Err(EncodingError::InvalidInput(
"Checked token is not reachable through swap path".to_string(),
))
} else {
Err(EncodingError::InvalidInput(
"Some tokens are not connected to the main path".to_string(),
))
}
}
/// This strategy encoder is used for solutions that are sent directly to the pool.
/// Only 1 solution with 1 swap is supported.
pub struct ExecutorStrategyEncoder {}
@@ -780,8 +781,15 @@ mod tests {
println!("{}", _hex_calldata);
}
fn get_mock_split_swap_strategy_encoder() -> SplitSwapStrategyEncoder {
let private_key =
"0x123456789abcdef123456789abcdef123456789abcdef123456789abcdef1234".to_string();
SplitSwapStrategyEncoder::new(private_key, Chain::Ethereum).unwrap()
}
#[test]
fn test_validate_token_path_connectivity_single_swap() {
let encoder = get_mock_split_swap_strategy_encoder();
let weth = Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap();
let dai = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap();
let swaps = vec![Swap {
@@ -794,12 +802,13 @@ mod tests {
token_out: dai.clone(),
split: 0f64,
}];
let result = validate_token_path_connectivity(&swaps, &weth, &dai);
let result = encoder.validate_token_path_connectivity(&swaps, &weth, &dai);
assert_eq!(result, Ok(()));
}
#[test]
fn test_validate_token_path_connectivity_multiple_swaps() {
let encoder = get_mock_split_swap_strategy_encoder();
let weth = Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap();
let dai = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap();
let usdc = Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap();
@@ -825,12 +834,13 @@ mod tests {
split: 0f64,
},
];
let result = validate_token_path_connectivity(&swaps, &weth, &usdc);
let result = encoder.validate_token_path_connectivity(&swaps, &weth, &usdc);
assert_eq!(result, Ok(()));
}
#[test]
fn test_validate_token_path_connectivity_multiple_swaps_failure() {
let encoder = get_mock_split_swap_strategy_encoder();
let weth = Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap();
let dai = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap();
let usdc = Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap();
@@ -860,7 +870,7 @@ mod tests {
split: 0.0,
},
];
let result = validate_token_path_connectivity(&disconnected_swaps, &weth, &usdc);
let result = encoder.validate_token_path_connectivity(&disconnected_swaps, &weth, &usdc);
assert!(matches!(
result,
Err(EncodingError::InvalidInput(msg)) if msg.contains("not reachable through swap path")
@@ -877,7 +887,7 @@ mod tests {
token_out: dai.clone(),
split: 1.0,
}];
let result = validate_token_path_connectivity(&unreachable_swaps, &weth, &usdc);
let result = encoder.validate_token_path_connectivity(&unreachable_swaps, &weth, &usdc);
assert!(matches!(
result,
Err(EncodingError::InvalidInput(msg)) if msg.contains("not reachable through swap path")
@@ -885,7 +895,7 @@ mod tests {
// Test case 3: Empty swaps
let empty_swaps: Vec<Swap> = vec![];
let result = validate_token_path_connectivity(&empty_swaps, &weth, &usdc);
let result = encoder.validate_token_path_connectivity(&empty_swaps, &weth, &usdc);
assert!(matches!(
result,
Err(EncodingError::InvalidInput(msg)) if msg.contains("not reachable through swap path")
@@ -894,6 +904,7 @@ mod tests {
#[test]
fn test_validate_swaps_single_swap() {
let encoder = get_mock_split_swap_strategy_encoder();
let weth = Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap();
let dai = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap();
let swaps = vec![Swap {
@@ -906,7 +917,7 @@ mod tests {
token_out: dai.clone(),
split: 0f64,
}];
let result = validate_swaps(&swaps);
let result = encoder.validate_swaps(&swaps);
assert_eq!(result, Ok(()));
}
@@ -948,7 +959,10 @@ mod tests {
split: 0.0, // Remainder (20%)
},
];
assert!(validate_swaps(&valid_swaps).is_ok());
let encoder = get_mock_split_swap_strategy_encoder();
assert!(encoder
.validate_swaps(&valid_swaps)
.is_ok());
}
#[test]
@@ -979,8 +993,9 @@ mod tests {
split: 0.3,
},
];
let encoder = get_mock_split_swap_strategy_encoder();
assert!(matches!(
validate_swaps(&invalid_total_swaps),
encoder.validate_swaps(&invalid_total_swaps),
Err(EncodingError::InvalidInput(msg)) if msg.contains("must have exactly one 0% split")
));
@@ -1008,7 +1023,7 @@ mod tests {
},
];
assert!(matches!(
validate_swaps(&invalid_zero_position_swaps),
encoder.validate_swaps(&invalid_zero_position_swaps),
Err(EncodingError::InvalidInput(msg)) if msg.contains("must be the last swap")
));
@@ -1045,8 +1060,9 @@ mod tests {
split: 0.0,
},
];
let encoder = get_mock_split_swap_strategy_encoder();
assert!(matches!(
validate_swaps(&invalid_overflow_swaps),
encoder.validate_swaps(&invalid_overflow_swaps),
Err(EncodingError::InvalidInput(msg)) if msg.contains("must be <100%")
));
}