chore: move validation methods inside impl
This commit is contained in:
@@ -67,6 +67,145 @@ impl SplitSwapStrategyEncoder {
|
||||
let selector = "swap(uint256,address,address,uint256,bool,bool,uint256,address,((address,uint160,uint48,uint48),address,uint256),bytes,bytes)".to_string();
|
||||
Ok(Self { permit2: Permit2::new(signer_pk, chain)?, selector })
|
||||
}
|
||||
|
||||
fn validate_swaps(&self, swaps: &[Swap]) -> Result<(), EncodingError> {
|
||||
for swap in swaps {
|
||||
if swap.split >= 1.0 {
|
||||
return Err(EncodingError::InvalidInput(format!(
|
||||
"Split percentage must be less than 1 (100%), got {}",
|
||||
swap.split
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
let mut swaps_by_token: HashMap<Bytes, Vec<&Swap>> = HashMap::new();
|
||||
|
||||
for swap in swaps {
|
||||
swaps_by_token
|
||||
.entry(swap.token_in.clone())
|
||||
.or_default()
|
||||
.push(swap);
|
||||
}
|
||||
|
||||
for (token, token_swaps) in swaps_by_token {
|
||||
if token_swaps.is_empty() {
|
||||
return Err(EncodingError::InvalidInput(format!(
|
||||
"No swaps found for token {:?}",
|
||||
token
|
||||
)));
|
||||
}
|
||||
|
||||
// Single swaps don't need remainder handling
|
||||
if token_swaps.len() == 1 {
|
||||
if token_swaps[0].split != 0.0 {
|
||||
return Err(EncodingError::InvalidInput(format!(
|
||||
"Single swap must have 0% split for token {:?}",
|
||||
token
|
||||
)));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if exactly one swap has 0% split and it's the last one
|
||||
let mut found_zero_split = false;
|
||||
for (i, swap) in token_swaps.iter().enumerate() {
|
||||
match (swap.split == 0.0, i == token_swaps.len() - 1) {
|
||||
(true, false) => {
|
||||
return Err(EncodingError::InvalidInput(format!(
|
||||
"The 0% split for token {:?} must be the last swap",
|
||||
token
|
||||
)))
|
||||
}
|
||||
(true, true) => found_zero_split = true,
|
||||
(false, _) => (),
|
||||
}
|
||||
}
|
||||
|
||||
if !found_zero_split {
|
||||
return Err(EncodingError::InvalidInput(format!(
|
||||
"Token {:?} must have exactly one 0% split for remainder handling",
|
||||
token
|
||||
)));
|
||||
}
|
||||
|
||||
// Sum non-zero splits and validate each is >0% and <100%
|
||||
let mut total_percentage = 0.0;
|
||||
for swap in token_swaps
|
||||
.iter()
|
||||
.take(token_swaps.len() - 1)
|
||||
{
|
||||
if swap.split <= 0.0 {
|
||||
return Err(EncodingError::InvalidInput(format!(
|
||||
"Non-remainder splits must be >0% for token {:?}",
|
||||
token
|
||||
)));
|
||||
}
|
||||
total_percentage += swap.split;
|
||||
}
|
||||
|
||||
// Total must be <100% to leave room for remainder
|
||||
if total_percentage >= 1.0 {
|
||||
return Err(EncodingError::InvalidInput(format!(
|
||||
"Total of non-remainder splits for token {:?} must be <100%, got {}%",
|
||||
token,
|
||||
total_percentage * 100.0
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_token_path_connectivity(
|
||||
&self,
|
||||
swaps: &[Swap],
|
||||
given_token: &Bytes,
|
||||
checked_token: &Bytes,
|
||||
) -> Result<(), EncodingError> {
|
||||
// Build directed graph of token flows
|
||||
let mut graph: HashMap<&Bytes, HashSet<&Bytes>> = HashMap::new();
|
||||
for swap in swaps {
|
||||
graph
|
||||
.entry(&swap.token_in)
|
||||
.or_default()
|
||||
.insert(&swap.token_out);
|
||||
}
|
||||
|
||||
// BFS from given_token
|
||||
let mut visited = HashSet::new();
|
||||
let mut queue = VecDeque::new();
|
||||
queue.push_back(given_token);
|
||||
|
||||
while let Some(token) = queue.pop_front() {
|
||||
if !visited.insert(token) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Early success check
|
||||
if token == checked_token && visited.len() == graph.len() + 1 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if let Some(next_tokens) = graph.get(token) {
|
||||
for &next_token in next_tokens {
|
||||
if !visited.contains(next_token) {
|
||||
queue.push_back(next_token);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we get here, either checked_token wasn't reached or not all tokens were visited
|
||||
if !visited.contains(checked_token) {
|
||||
Err(EncodingError::InvalidInput(
|
||||
"Checked token is not reachable through swap path".to_string(),
|
||||
))
|
||||
} else {
|
||||
Err(EncodingError::InvalidInput(
|
||||
"Some tokens are not connected to the main path".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
impl EVMStrategyEncoder for SplitSwapStrategyEncoder {}
|
||||
impl StrategyEncoder for SplitSwapStrategyEncoder {
|
||||
@@ -75,8 +214,8 @@ impl StrategyEncoder for SplitSwapStrategyEncoder {
|
||||
solution: Solution,
|
||||
router_address: Bytes,
|
||||
) -> Result<(Vec<u8>, Bytes), EncodingError> {
|
||||
validate_swaps(&solution.swaps)?;
|
||||
validate_token_path_connectivity(
|
||||
self.validate_swaps(&solution.swaps)?;
|
||||
self.validate_token_path_connectivity(
|
||||
&solution.swaps,
|
||||
&solution.given_token,
|
||||
&solution.checked_token,
|
||||
@@ -213,144 +352,6 @@ impl StrategyEncoder for SplitSwapStrategyEncoder {
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_swaps(swaps: &[Swap]) -> Result<(), EncodingError> {
|
||||
for swap in swaps {
|
||||
if swap.split >= 1.0 {
|
||||
return Err(EncodingError::InvalidInput(format!(
|
||||
"Split percentage must be less than 1 (100%), got {}",
|
||||
swap.split
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
let mut swaps_by_token: HashMap<Bytes, Vec<&Swap>> = HashMap::new();
|
||||
|
||||
for swap in swaps {
|
||||
swaps_by_token
|
||||
.entry(swap.token_in.clone())
|
||||
.or_default()
|
||||
.push(swap);
|
||||
}
|
||||
|
||||
for (token, token_swaps) in swaps_by_token {
|
||||
if token_swaps.is_empty() {
|
||||
return Err(EncodingError::InvalidInput(format!(
|
||||
"No swaps found for token {:?}",
|
||||
token
|
||||
)));
|
||||
}
|
||||
|
||||
// Single swaps don't need remainder handling
|
||||
if token_swaps.len() == 1 {
|
||||
if token_swaps[0].split != 0.0 {
|
||||
return Err(EncodingError::InvalidInput(format!(
|
||||
"Single swap must have 0% split for token {:?}",
|
||||
token
|
||||
)));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if exactly one swap has 0% split and it's the last one
|
||||
let mut found_zero_split = false;
|
||||
for (i, swap) in token_swaps.iter().enumerate() {
|
||||
match (swap.split == 0.0, i == token_swaps.len() - 1) {
|
||||
(true, false) => {
|
||||
return Err(EncodingError::InvalidInput(format!(
|
||||
"The 0% split for token {:?} must be the last swap",
|
||||
token
|
||||
)))
|
||||
}
|
||||
(true, true) => found_zero_split = true,
|
||||
(false, _) => (),
|
||||
}
|
||||
}
|
||||
|
||||
if !found_zero_split {
|
||||
return Err(EncodingError::InvalidInput(format!(
|
||||
"Token {:?} must have exactly one 0% split for remainder handling",
|
||||
token
|
||||
)));
|
||||
}
|
||||
|
||||
// Sum non-zero splits and validate each is >0% and <100%
|
||||
let mut total_percentage = 0.0;
|
||||
for swap in token_swaps
|
||||
.iter()
|
||||
.take(token_swaps.len() - 1)
|
||||
{
|
||||
if swap.split <= 0.0 {
|
||||
return Err(EncodingError::InvalidInput(format!(
|
||||
"Non-remainder splits must be >0% for token {:?}",
|
||||
token
|
||||
)));
|
||||
}
|
||||
total_percentage += swap.split;
|
||||
}
|
||||
|
||||
// Total must be <100% to leave room for remainder
|
||||
if total_percentage >= 1.0 {
|
||||
return Err(EncodingError::InvalidInput(format!(
|
||||
"Total of non-remainder splits for token {:?} must be <100%, got {}%",
|
||||
token,
|
||||
total_percentage * 100.0
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_token_path_connectivity(
|
||||
swaps: &[Swap],
|
||||
given_token: &Bytes,
|
||||
checked_token: &Bytes,
|
||||
) -> Result<(), EncodingError> {
|
||||
// Build directed graph of token flows
|
||||
let mut graph: HashMap<&Bytes, HashSet<&Bytes>> = HashMap::new();
|
||||
for swap in swaps {
|
||||
graph
|
||||
.entry(&swap.token_in)
|
||||
.or_default()
|
||||
.insert(&swap.token_out);
|
||||
}
|
||||
|
||||
// BFS from given_token
|
||||
let mut visited = HashSet::new();
|
||||
let mut queue = VecDeque::new();
|
||||
queue.push_back(given_token);
|
||||
|
||||
while let Some(token) = queue.pop_front() {
|
||||
if !visited.insert(token) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Early success check
|
||||
if token == checked_token && visited.len() == graph.len() + 1 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if let Some(next_tokens) = graph.get(token) {
|
||||
for &next_token in next_tokens {
|
||||
if !visited.contains(next_token) {
|
||||
queue.push_back(next_token);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we get here, either checked_token wasn't reached or not all tokens were visited
|
||||
if !visited.contains(checked_token) {
|
||||
Err(EncodingError::InvalidInput(
|
||||
"Checked token is not reachable through swap path".to_string(),
|
||||
))
|
||||
} else {
|
||||
Err(EncodingError::InvalidInput(
|
||||
"Some tokens are not connected to the main path".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// This strategy encoder is used for solutions that are sent directly to the pool.
|
||||
/// Only 1 solution with 1 swap is supported.
|
||||
pub struct ExecutorStrategyEncoder {}
|
||||
@@ -780,8 +781,15 @@ mod tests {
|
||||
println!("{}", _hex_calldata);
|
||||
}
|
||||
|
||||
fn get_mock_split_swap_strategy_encoder() -> SplitSwapStrategyEncoder {
|
||||
let private_key =
|
||||
"0x123456789abcdef123456789abcdef123456789abcdef123456789abcdef1234".to_string();
|
||||
SplitSwapStrategyEncoder::new(private_key, Chain::Ethereum).unwrap()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_token_path_connectivity_single_swap() {
|
||||
let encoder = get_mock_split_swap_strategy_encoder();
|
||||
let weth = Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap();
|
||||
let dai = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap();
|
||||
let swaps = vec![Swap {
|
||||
@@ -794,12 +802,13 @@ mod tests {
|
||||
token_out: dai.clone(),
|
||||
split: 0f64,
|
||||
}];
|
||||
let result = validate_token_path_connectivity(&swaps, &weth, &dai);
|
||||
let result = encoder.validate_token_path_connectivity(&swaps, &weth, &dai);
|
||||
assert_eq!(result, Ok(()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_token_path_connectivity_multiple_swaps() {
|
||||
let encoder = get_mock_split_swap_strategy_encoder();
|
||||
let weth = Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap();
|
||||
let dai = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap();
|
||||
let usdc = Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap();
|
||||
@@ -825,12 +834,13 @@ mod tests {
|
||||
split: 0f64,
|
||||
},
|
||||
];
|
||||
let result = validate_token_path_connectivity(&swaps, &weth, &usdc);
|
||||
let result = encoder.validate_token_path_connectivity(&swaps, &weth, &usdc);
|
||||
assert_eq!(result, Ok(()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_token_path_connectivity_multiple_swaps_failure() {
|
||||
let encoder = get_mock_split_swap_strategy_encoder();
|
||||
let weth = Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap();
|
||||
let dai = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap();
|
||||
let usdc = Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap();
|
||||
@@ -860,7 +870,7 @@ mod tests {
|
||||
split: 0.0,
|
||||
},
|
||||
];
|
||||
let result = validate_token_path_connectivity(&disconnected_swaps, &weth, &usdc);
|
||||
let result = encoder.validate_token_path_connectivity(&disconnected_swaps, &weth, &usdc);
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(EncodingError::InvalidInput(msg)) if msg.contains("not reachable through swap path")
|
||||
@@ -877,7 +887,7 @@ mod tests {
|
||||
token_out: dai.clone(),
|
||||
split: 1.0,
|
||||
}];
|
||||
let result = validate_token_path_connectivity(&unreachable_swaps, &weth, &usdc);
|
||||
let result = encoder.validate_token_path_connectivity(&unreachable_swaps, &weth, &usdc);
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(EncodingError::InvalidInput(msg)) if msg.contains("not reachable through swap path")
|
||||
@@ -885,7 +895,7 @@ mod tests {
|
||||
|
||||
// Test case 3: Empty swaps
|
||||
let empty_swaps: Vec<Swap> = vec![];
|
||||
let result = validate_token_path_connectivity(&empty_swaps, &weth, &usdc);
|
||||
let result = encoder.validate_token_path_connectivity(&empty_swaps, &weth, &usdc);
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(EncodingError::InvalidInput(msg)) if msg.contains("not reachable through swap path")
|
||||
@@ -894,6 +904,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_validate_swaps_single_swap() {
|
||||
let encoder = get_mock_split_swap_strategy_encoder();
|
||||
let weth = Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap();
|
||||
let dai = Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap();
|
||||
let swaps = vec![Swap {
|
||||
@@ -906,7 +917,7 @@ mod tests {
|
||||
token_out: dai.clone(),
|
||||
split: 0f64,
|
||||
}];
|
||||
let result = validate_swaps(&swaps);
|
||||
let result = encoder.validate_swaps(&swaps);
|
||||
assert_eq!(result, Ok(()));
|
||||
}
|
||||
|
||||
@@ -948,7 +959,10 @@ mod tests {
|
||||
split: 0.0, // Remainder (20%)
|
||||
},
|
||||
];
|
||||
assert!(validate_swaps(&valid_swaps).is_ok());
|
||||
let encoder = get_mock_split_swap_strategy_encoder();
|
||||
assert!(encoder
|
||||
.validate_swaps(&valid_swaps)
|
||||
.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -979,8 +993,9 @@ mod tests {
|
||||
split: 0.3,
|
||||
},
|
||||
];
|
||||
let encoder = get_mock_split_swap_strategy_encoder();
|
||||
assert!(matches!(
|
||||
validate_swaps(&invalid_total_swaps),
|
||||
encoder.validate_swaps(&invalid_total_swaps),
|
||||
Err(EncodingError::InvalidInput(msg)) if msg.contains("must have exactly one 0% split")
|
||||
));
|
||||
|
||||
@@ -1008,7 +1023,7 @@ mod tests {
|
||||
},
|
||||
];
|
||||
assert!(matches!(
|
||||
validate_swaps(&invalid_zero_position_swaps),
|
||||
encoder.validate_swaps(&invalid_zero_position_swaps),
|
||||
Err(EncodingError::InvalidInput(msg)) if msg.contains("must be the last swap")
|
||||
));
|
||||
|
||||
@@ -1045,8 +1060,9 @@ mod tests {
|
||||
split: 0.0,
|
||||
},
|
||||
];
|
||||
let encoder = get_mock_split_swap_strategy_encoder();
|
||||
assert!(matches!(
|
||||
validate_swaps(&invalid_overflow_swaps),
|
||||
encoder.validate_swaps(&invalid_overflow_swaps),
|
||||
Err(EncodingError::InvalidInput(msg)) if msg.contains("must be <100%")
|
||||
));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user