Merge pull request #103 from propeller-heads/encoding/tnl/ENG-4300-encorce-min-amount

feat: enforce checked amount when encoding to router
This commit is contained in:
Tamara
2025-03-06 09:07:46 -05:00
committed by GitHub
2 changed files with 95 additions and 8 deletions

View File

@@ -107,6 +107,8 @@ impl EVMStrategyEncoder for SplitSwapStrategyEncoder {}
impl StrategyEncoder for SplitSwapStrategyEncoder {
fn encode_strategy(&self, solution: Solution) -> Result<(Vec<u8>, Bytes), EncodingError> {
self.split_swap_validator
.validate_solution_min_amounts(&solution)?;
self.split_swap_validator
.validate_split_percentages(&solution.swaps)?;
self.split_swap_validator
@@ -555,12 +557,6 @@ mod tests {
}
#[rstest]
#[case::no_check_no_slippage(
None,
None,
None,
U256::from_str("0").unwrap(),
)]
#[case::with_check_no_slippage(
None,
None,
@@ -764,7 +760,7 @@ mod tests {
given_amount: BigUint::from_str("3_000_000000000000000000").unwrap(),
checked_token: eth(),
expected_amount: Some(BigUint::from_str("1_000000000000000000").unwrap()),
checked_amount: None,
checked_amount: Some(BigUint::from_str("1_000000000000000000").unwrap()),
sender: Bytes::from_str("0xcd09f75E2BF2A4d11F3AB23f1389FcC1621c0cc2").unwrap(),
receiver: Bytes::from_str("0xcd09f75E2BF2A4d11F3AB23f1389FcC1621c0cc2").unwrap(),
router_address: Bytes::from_str("0x3Ede3eCa2a72B3aeCC820E955B36f38437D01395").unwrap(),

View File

@@ -4,7 +4,7 @@ use tycho_core::Bytes;
use crate::encoding::{
errors::EncodingError,
models::{NativeAction, Swap},
models::{NativeAction, Solution, Swap},
};
/// Validates whether a sequence of split swaps represents a valid solution.
@@ -90,6 +90,19 @@ impl SplitSwapValidator {
Ok(())
}
/// Raises an error if the solution does not have checked amount set or slippage with checked
/// amount set.
pub fn validate_solution_min_amounts(&self, solution: &Solution) -> Result<(), EncodingError> {
if solution.checked_amount.is_none() &&
(solution.slippage.is_none() || solution.expected_amount.is_none())
{
return Err(EncodingError::InvalidInput(
"Checked amount or slippage with expected amount must be provided".to_string(),
))
}
Ok(())
}
/// Raises an error if swaps do not represent a valid path from the given token to the checked
/// token.
///
@@ -178,6 +191,8 @@ impl SplitSwapValidator {
mod tests {
use std::str::FromStr;
use num_bigint::BigUint;
use rstest::rstest;
use tycho_core::{models::protocol::ProtocolComponent, Bytes};
use super::*;
@@ -549,4 +564,80 @@ mod tests {
);
assert_eq!(result, Ok(()));
}
#[rstest]
#[case::slippage_with_expected_amount_set(
Some(0.01),
Some(BigUint::from(1000u32)),
None,
Ok(())
)]
#[case::min_amount_set(
None,
None,
Some(BigUint::from(1000u32)),
Ok(())
)]
#[case::slippage_with_min_amount_set(
Some(0.01),
Some(BigUint::from(1000u32)),
Some(BigUint::from(1000u32)),
Ok(())
)]
#[case::slippage_without_expected_amount_set(
Some(0.01),
None,
None,
Err(
EncodingError::InvalidInput(
"Checked amount or slippage with expected amount must be provided".to_string()
)
)
)]
#[case::none_set(
None,
None,
None,
Err(
EncodingError::InvalidInput(
"Checked amount or slippage with expected amount must be provided".to_string()
)
)
)]
fn test_validate_min_amount_passed(
#[case] slippage: Option<f64>,
#[case] expected_amount: Option<BigUint>,
#[case] min_amount: Option<BigUint>,
#[case] expected_result: Result<(), EncodingError>,
) {
let weth = Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap();
let usdc = Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap();
let validator = SplitSwapValidator;
let swap = Swap {
component: ProtocolComponent {
id: "0xA478c2975Ab1Ea89e8196811F51A7B7Ade33eB11".to_string(),
protocol_system: "uniswap_v2".to_string(),
..Default::default()
},
token_in: weth.clone(),
token_out: usdc.clone(),
split: 0f64,
};
let solution = Solution {
exact_out: false,
given_token: weth,
checked_token: usdc,
slippage,
checked_amount: min_amount,
expected_amount,
swaps: vec![swap],
native_action: Some(NativeAction::Wrap),
..Default::default()
};
let result = validator.validate_solution_min_amounts(&solution);
assert_eq!(result, expected_result);
}
}