fix: use native action to validate path
This commit is contained in:
@@ -95,8 +95,8 @@ impl SplitSwapStrategyEncoder {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if exactly one swap has 0% split and it's the last one
|
|
||||||
let mut found_zero_split = false;
|
let mut found_zero_split = false;
|
||||||
|
let mut total_percentage = 0.0;
|
||||||
for (i, swap) in token_swaps.iter().enumerate() {
|
for (i, swap) in token_swaps.iter().enumerate() {
|
||||||
match (swap.split == 0.0, i == token_swaps.len() - 1) {
|
match (swap.split == 0.0, i == token_swaps.len() - 1) {
|
||||||
(true, false) => {
|
(true, false) => {
|
||||||
@@ -106,7 +106,15 @@ impl SplitSwapStrategyEncoder {
|
|||||||
)))
|
)))
|
||||||
}
|
}
|
||||||
(true, true) => found_zero_split = true,
|
(true, true) => found_zero_split = true,
|
||||||
(false, _) => (),
|
(false, _) => {
|
||||||
|
if swap.split <= 0.0 {
|
||||||
|
return Err(EncodingError::InvalidInput(format!(
|
||||||
|
"Non-remainder splits must be >0% for token {:?}",
|
||||||
|
token
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
total_percentage += swap.split;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -117,21 +125,6 @@ impl SplitSwapStrategyEncoder {
|
|||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sum non-zero splits and validate each is >0% and <100%
|
|
||||||
let mut total_percentage = 0.0;
|
|
||||||
for swap in token_swaps
|
|
||||||
.iter()
|
|
||||||
.take(token_swaps.len() - 1)
|
|
||||||
{
|
|
||||||
if swap.split <= 0.0 {
|
|
||||||
return Err(EncodingError::InvalidInput(format!(
|
|
||||||
"Non-remainder splits must be >0% for token {:?}",
|
|
||||||
token
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
total_percentage += swap.split;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Total must be <100% to leave room for remainder
|
// Total must be <100% to leave room for remainder
|
||||||
if total_percentage >= 1.0 {
|
if total_percentage >= 1.0 {
|
||||||
return Err(EncodingError::InvalidInput(format!(
|
return Err(EncodingError::InvalidInput(format!(
|
||||||
@@ -150,13 +143,26 @@ impl SplitSwapStrategyEncoder {
|
|||||||
swaps: &[Swap],
|
swaps: &[Swap],
|
||||||
given_token: &Bytes,
|
given_token: &Bytes,
|
||||||
checked_token: &Bytes,
|
checked_token: &Bytes,
|
||||||
|
native_action: &Option<NativeAction>,
|
||||||
) -> Result<(), EncodingError> {
|
) -> Result<(), EncodingError> {
|
||||||
// Special case: If given_token is ETH or checked_token is ETH, treat it as WETH for path
|
// Convert ETH to WETH only if there's a corresponding wrap/unwrap action
|
||||||
// validation
|
let given_token = if *given_token == *NATIVE_ADDRESS {
|
||||||
let given_token = if *given_token == *NATIVE_ADDRESS { &WETH_ADDRESS } else { given_token };
|
match native_action {
|
||||||
|
Some(NativeAction::Wrap) => &WETH_ADDRESS,
|
||||||
|
_ => given_token,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
given_token
|
||||||
|
};
|
||||||
|
|
||||||
let checked_token =
|
let checked_token = if *checked_token == *NATIVE_ADDRESS {
|
||||||
if *checked_token == *NATIVE_ADDRESS { &WETH_ADDRESS } else { checked_token };
|
match native_action {
|
||||||
|
Some(NativeAction::Unwrap) => &WETH_ADDRESS,
|
||||||
|
_ => checked_token,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
checked_token
|
||||||
|
};
|
||||||
|
|
||||||
// Build directed graph of token flows
|
// Build directed graph of token flows
|
||||||
let mut graph: HashMap<&Bytes, HashSet<&Bytes>> = HashMap::new();
|
let mut graph: HashMap<&Bytes, HashSet<&Bytes>> = HashMap::new();
|
||||||
@@ -211,7 +217,12 @@ impl StrategyEncoder for SplitSwapStrategyEncoder {
|
|||||||
router_address: Bytes,
|
router_address: Bytes,
|
||||||
) -> Result<(Vec<u8>, Bytes), EncodingError> {
|
) -> Result<(Vec<u8>, Bytes), EncodingError> {
|
||||||
self.validate_split_percentages(&solution.swaps)?;
|
self.validate_split_percentages(&solution.swaps)?;
|
||||||
self.validate_swap_path(&solution.swaps, &solution.given_token, &solution.checked_token)?;
|
self.validate_swap_path(
|
||||||
|
&solution.swaps,
|
||||||
|
&solution.given_token,
|
||||||
|
&solution.checked_token,
|
||||||
|
&solution.native_action,
|
||||||
|
)?;
|
||||||
let (permit, signature) = self.permit2.get_permit(
|
let (permit, signature) = self.permit2.get_permit(
|
||||||
&router_address,
|
&router_address,
|
||||||
&solution.sender,
|
&solution.sender,
|
||||||
@@ -794,7 +805,7 @@ mod tests {
|
|||||||
token_out: dai.clone(),
|
token_out: dai.clone(),
|
||||||
split: 0f64,
|
split: 0f64,
|
||||||
}];
|
}];
|
||||||
let result = encoder.validate_swap_path(&swaps, &weth, &dai);
|
let result = encoder.validate_swap_path(&swaps, &weth, &dai, &None);
|
||||||
assert_eq!(result, Ok(()));
|
assert_eq!(result, Ok(()));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -826,7 +837,7 @@ mod tests {
|
|||||||
split: 0f64,
|
split: 0f64,
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
let result = encoder.validate_swap_path(&swaps, &weth, &usdc);
|
let result = encoder.validate_swap_path(&swaps, &weth, &usdc, &None);
|
||||||
assert_eq!(result, Ok(()));
|
assert_eq!(result, Ok(()));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -861,7 +872,7 @@ mod tests {
|
|||||||
split: 0.0,
|
split: 0.0,
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
let result = encoder.validate_swap_path(&disconnected_swaps, &weth, &usdc);
|
let result = encoder.validate_swap_path(&disconnected_swaps, &weth, &usdc, &None);
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
result,
|
result,
|
||||||
Err(EncodingError::InvalidInput(msg)) if msg.contains("not reachable through swap path")
|
Err(EncodingError::InvalidInput(msg)) if msg.contains("not reachable through swap path")
|
||||||
@@ -885,7 +896,7 @@ mod tests {
|
|||||||
token_out: dai.clone(),
|
token_out: dai.clone(),
|
||||||
split: 1.0,
|
split: 1.0,
|
||||||
}];
|
}];
|
||||||
let result = encoder.validate_swap_path(&unreachable_swaps, &weth, &usdc);
|
let result = encoder.validate_swap_path(&unreachable_swaps, &weth, &usdc, &None);
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
result,
|
result,
|
||||||
Err(EncodingError::InvalidInput(msg)) if msg.contains("not reachable through swap path")
|
Err(EncodingError::InvalidInput(msg)) if msg.contains("not reachable through swap path")
|
||||||
@@ -899,7 +910,7 @@ mod tests {
|
|||||||
let usdc = Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap();
|
let usdc = Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap();
|
||||||
|
|
||||||
let empty_swaps: Vec<Swap> = vec![];
|
let empty_swaps: Vec<Swap> = vec![];
|
||||||
let result = encoder.validate_swap_path(&empty_swaps, &weth, &usdc);
|
let result = encoder.validate_swap_path(&empty_swaps, &weth, &usdc, &None);
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
result,
|
result,
|
||||||
Err(EncodingError::InvalidInput(msg)) if msg.contains("not reachable through swap path")
|
Err(EncodingError::InvalidInput(msg)) if msg.contains("not reachable through swap path")
|
||||||
@@ -1100,7 +1111,7 @@ mod tests {
|
|||||||
split: 0f64,
|
split: 0f64,
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let result = encoder.validate_swap_path(&swaps, ð, &usdc);
|
let result = encoder.validate_swap_path(&swaps, ð, &usdc, &Some(NativeAction::Wrap));
|
||||||
assert_eq!(result, Ok(()));
|
assert_eq!(result, Ok(()));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1123,7 +1134,7 @@ mod tests {
|
|||||||
split: 0f64,
|
split: 0f64,
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let result = encoder.validate_swap_path(&swaps, &usdc, ð);
|
let result = encoder.validate_swap_path(&swaps, &usdc, ð, &Some(NativeAction::Unwrap));
|
||||||
assert_eq!(result, Ok(()));
|
assert_eq!(result, Ok(()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user