fix: use native action to validate path
This commit is contained in:
@@ -95,8 +95,8 @@ impl SplitSwapStrategyEncoder {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if exactly one swap has 0% split and it's the last one
|
||||
let mut found_zero_split = false;
|
||||
let mut total_percentage = 0.0;
|
||||
for (i, swap) in token_swaps.iter().enumerate() {
|
||||
match (swap.split == 0.0, i == token_swaps.len() - 1) {
|
||||
(true, false) => {
|
||||
@@ -106,7 +106,15 @@ impl SplitSwapStrategyEncoder {
|
||||
)))
|
||||
}
|
||||
(true, true) => found_zero_split = true,
|
||||
(false, _) => (),
|
||||
(false, _) => {
|
||||
if swap.split <= 0.0 {
|
||||
return Err(EncodingError::InvalidInput(format!(
|
||||
"Non-remainder splits must be >0% for token {:?}",
|
||||
token
|
||||
)));
|
||||
}
|
||||
total_percentage += swap.split;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,21 +125,6 @@ impl SplitSwapStrategyEncoder {
|
||||
)));
|
||||
}
|
||||
|
||||
// Sum non-zero splits and validate each is >0% and <100%
|
||||
let mut total_percentage = 0.0;
|
||||
for swap in token_swaps
|
||||
.iter()
|
||||
.take(token_swaps.len() - 1)
|
||||
{
|
||||
if swap.split <= 0.0 {
|
||||
return Err(EncodingError::InvalidInput(format!(
|
||||
"Non-remainder splits must be >0% for token {:?}",
|
||||
token
|
||||
)));
|
||||
}
|
||||
total_percentage += swap.split;
|
||||
}
|
||||
|
||||
// Total must be <100% to leave room for remainder
|
||||
if total_percentage >= 1.0 {
|
||||
return Err(EncodingError::InvalidInput(format!(
|
||||
@@ -150,13 +143,26 @@ impl SplitSwapStrategyEncoder {
|
||||
swaps: &[Swap],
|
||||
given_token: &Bytes,
|
||||
checked_token: &Bytes,
|
||||
native_action: &Option<NativeAction>,
|
||||
) -> Result<(), EncodingError> {
|
||||
// Special case: If given_token is ETH or checked_token is ETH, treat it as WETH for path
|
||||
// validation
|
||||
let given_token = if *given_token == *NATIVE_ADDRESS { &WETH_ADDRESS } else { given_token };
|
||||
// Convert ETH to WETH only if there's a corresponding wrap/unwrap action
|
||||
let given_token = if *given_token == *NATIVE_ADDRESS {
|
||||
match native_action {
|
||||
Some(NativeAction::Wrap) => &WETH_ADDRESS,
|
||||
_ => given_token,
|
||||
}
|
||||
} else {
|
||||
given_token
|
||||
};
|
||||
|
||||
let checked_token =
|
||||
if *checked_token == *NATIVE_ADDRESS { &WETH_ADDRESS } else { checked_token };
|
||||
let checked_token = if *checked_token == *NATIVE_ADDRESS {
|
||||
match native_action {
|
||||
Some(NativeAction::Unwrap) => &WETH_ADDRESS,
|
||||
_ => checked_token,
|
||||
}
|
||||
} else {
|
||||
checked_token
|
||||
};
|
||||
|
||||
// Build directed graph of token flows
|
||||
let mut graph: HashMap<&Bytes, HashSet<&Bytes>> = HashMap::new();
|
||||
@@ -211,7 +217,12 @@ impl StrategyEncoder for SplitSwapStrategyEncoder {
|
||||
router_address: Bytes,
|
||||
) -> Result<(Vec<u8>, Bytes), EncodingError> {
|
||||
self.validate_split_percentages(&solution.swaps)?;
|
||||
self.validate_swap_path(&solution.swaps, &solution.given_token, &solution.checked_token)?;
|
||||
self.validate_swap_path(
|
||||
&solution.swaps,
|
||||
&solution.given_token,
|
||||
&solution.checked_token,
|
||||
&solution.native_action,
|
||||
)?;
|
||||
let (permit, signature) = self.permit2.get_permit(
|
||||
&router_address,
|
||||
&solution.sender,
|
||||
@@ -794,7 +805,7 @@ mod tests {
|
||||
token_out: dai.clone(),
|
||||
split: 0f64,
|
||||
}];
|
||||
let result = encoder.validate_swap_path(&swaps, &weth, &dai);
|
||||
let result = encoder.validate_swap_path(&swaps, &weth, &dai, &None);
|
||||
assert_eq!(result, Ok(()));
|
||||
}
|
||||
|
||||
@@ -826,7 +837,7 @@ mod tests {
|
||||
split: 0f64,
|
||||
},
|
||||
];
|
||||
let result = encoder.validate_swap_path(&swaps, &weth, &usdc);
|
||||
let result = encoder.validate_swap_path(&swaps, &weth, &usdc, &None);
|
||||
assert_eq!(result, Ok(()));
|
||||
}
|
||||
|
||||
@@ -861,7 +872,7 @@ mod tests {
|
||||
split: 0.0,
|
||||
},
|
||||
];
|
||||
let result = encoder.validate_swap_path(&disconnected_swaps, &weth, &usdc);
|
||||
let result = encoder.validate_swap_path(&disconnected_swaps, &weth, &usdc, &None);
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(EncodingError::InvalidInput(msg)) if msg.contains("not reachable through swap path")
|
||||
@@ -885,7 +896,7 @@ mod tests {
|
||||
token_out: dai.clone(),
|
||||
split: 1.0,
|
||||
}];
|
||||
let result = encoder.validate_swap_path(&unreachable_swaps, &weth, &usdc);
|
||||
let result = encoder.validate_swap_path(&unreachable_swaps, &weth, &usdc, &None);
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(EncodingError::InvalidInput(msg)) if msg.contains("not reachable through swap path")
|
||||
@@ -899,7 +910,7 @@ mod tests {
|
||||
let usdc = Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap();
|
||||
|
||||
let empty_swaps: Vec<Swap> = vec![];
|
||||
let result = encoder.validate_swap_path(&empty_swaps, &weth, &usdc);
|
||||
let result = encoder.validate_swap_path(&empty_swaps, &weth, &usdc, &None);
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(EncodingError::InvalidInput(msg)) if msg.contains("not reachable through swap path")
|
||||
@@ -1100,7 +1111,7 @@ mod tests {
|
||||
split: 0f64,
|
||||
}];
|
||||
|
||||
let result = encoder.validate_swap_path(&swaps, ð, &usdc);
|
||||
let result = encoder.validate_swap_path(&swaps, ð, &usdc, &Some(NativeAction::Wrap));
|
||||
assert_eq!(result, Ok(()));
|
||||
}
|
||||
|
||||
@@ -1123,7 +1134,7 @@ mod tests {
|
||||
split: 0f64,
|
||||
}];
|
||||
|
||||
let result = encoder.validate_swap_path(&swaps, &usdc, ð);
|
||||
let result = encoder.validate_swap_path(&swaps, &usdc, ð, &Some(NativeAction::Unwrap));
|
||||
assert_eq!(result, Ok(()));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user