fix: use native action to validate path

This commit is contained in:
royvardhan
2025-02-05 00:14:09 +05:30
parent 2f0013a934
commit c787f5e722

View File

@@ -95,8 +95,8 @@ impl SplitSwapStrategyEncoder {
continue;
}
// Check if exactly one swap has 0% split and it's the last one
let mut found_zero_split = false;
let mut total_percentage = 0.0;
for (i, swap) in token_swaps.iter().enumerate() {
match (swap.split == 0.0, i == token_swaps.len() - 1) {
(true, false) => {
@@ -106,7 +106,15 @@ impl SplitSwapStrategyEncoder {
)))
}
(true, true) => found_zero_split = true,
(false, _) => (),
(false, _) => {
if swap.split <= 0.0 {
return Err(EncodingError::InvalidInput(format!(
"Non-remainder splits must be >0% for token {:?}",
token
)));
}
total_percentage += swap.split;
}
}
}
@@ -117,21 +125,6 @@ impl SplitSwapStrategyEncoder {
)));
}
// Sum non-zero splits and validate each is >0% and <100%
let mut total_percentage = 0.0;
for swap in token_swaps
.iter()
.take(token_swaps.len() - 1)
{
if swap.split <= 0.0 {
return Err(EncodingError::InvalidInput(format!(
"Non-remainder splits must be >0% for token {:?}",
token
)));
}
total_percentage += swap.split;
}
// Total must be <100% to leave room for remainder
if total_percentage >= 1.0 {
return Err(EncodingError::InvalidInput(format!(
@@ -150,13 +143,26 @@ impl SplitSwapStrategyEncoder {
swaps: &[Swap],
given_token: &Bytes,
checked_token: &Bytes,
native_action: &Option<NativeAction>,
) -> Result<(), EncodingError> {
// Special case: If given_token is ETH or checked_token is ETH, treat it as WETH for path
// validation
let given_token = if *given_token == *NATIVE_ADDRESS { &WETH_ADDRESS } else { given_token };
// Convert ETH to WETH only if there's a corresponding wrap/unwrap action
let given_token = if *given_token == *NATIVE_ADDRESS {
match native_action {
Some(NativeAction::Wrap) => &WETH_ADDRESS,
_ => given_token,
}
} else {
given_token
};
let checked_token =
if *checked_token == *NATIVE_ADDRESS { &WETH_ADDRESS } else { checked_token };
let checked_token = if *checked_token == *NATIVE_ADDRESS {
match native_action {
Some(NativeAction::Unwrap) => &WETH_ADDRESS,
_ => checked_token,
}
} else {
checked_token
};
// Build directed graph of token flows
let mut graph: HashMap<&Bytes, HashSet<&Bytes>> = HashMap::new();
@@ -211,7 +217,12 @@ impl StrategyEncoder for SplitSwapStrategyEncoder {
router_address: Bytes,
) -> Result<(Vec<u8>, Bytes), EncodingError> {
self.validate_split_percentages(&solution.swaps)?;
self.validate_swap_path(&solution.swaps, &solution.given_token, &solution.checked_token)?;
self.validate_swap_path(
&solution.swaps,
&solution.given_token,
&solution.checked_token,
&solution.native_action,
)?;
let (permit, signature) = self.permit2.get_permit(
&router_address,
&solution.sender,
@@ -794,7 +805,7 @@ mod tests {
token_out: dai.clone(),
split: 0f64,
}];
let result = encoder.validate_swap_path(&swaps, &weth, &dai);
let result = encoder.validate_swap_path(&swaps, &weth, &dai, &None);
assert_eq!(result, Ok(()));
}
@@ -826,7 +837,7 @@ mod tests {
split: 0f64,
},
];
let result = encoder.validate_swap_path(&swaps, &weth, &usdc);
let result = encoder.validate_swap_path(&swaps, &weth, &usdc, &None);
assert_eq!(result, Ok(()));
}
@@ -861,7 +872,7 @@ mod tests {
split: 0.0,
},
];
let result = encoder.validate_swap_path(&disconnected_swaps, &weth, &usdc);
let result = encoder.validate_swap_path(&disconnected_swaps, &weth, &usdc, &None);
assert!(matches!(
result,
Err(EncodingError::InvalidInput(msg)) if msg.contains("not reachable through swap path")
@@ -885,7 +896,7 @@ mod tests {
token_out: dai.clone(),
split: 1.0,
}];
let result = encoder.validate_swap_path(&unreachable_swaps, &weth, &usdc);
let result = encoder.validate_swap_path(&unreachable_swaps, &weth, &usdc, &None);
assert!(matches!(
result,
Err(EncodingError::InvalidInput(msg)) if msg.contains("not reachable through swap path")
@@ -899,7 +910,7 @@ mod tests {
let usdc = Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap();
let empty_swaps: Vec<Swap> = vec![];
let result = encoder.validate_swap_path(&empty_swaps, &weth, &usdc);
let result = encoder.validate_swap_path(&empty_swaps, &weth, &usdc, &None);
assert!(matches!(
result,
Err(EncodingError::InvalidInput(msg)) if msg.contains("not reachable through swap path")
@@ -1100,7 +1111,7 @@ mod tests {
split: 0f64,
}];
let result = encoder.validate_swap_path(&swaps, &eth, &usdc);
let result = encoder.validate_swap_path(&swaps, &eth, &usdc, &Some(NativeAction::Wrap));
assert_eq!(result, Ok(()));
}
@@ -1123,7 +1134,7 @@ mod tests {
split: 0f64,
}];
let result = encoder.validate_swap_path(&swaps, &usdc, &eth);
let result = encoder.validate_swap_path(&swaps, &usdc, &eth, &Some(NativeAction::Unwrap));
assert_eq!(result, Ok(()));
}
}