chore: Move get_receiver logic inside TransferOptimization

Added tests for all cases

--- don't change below this line ---
ENG-4446 Took 1 hour 8 minutes
This commit is contained in:
Diana Carvalho
2025-04-22 15:30:33 +01:00
parent cebacc68fe
commit dff4a345fc
2 changed files with 133 additions and 40 deletions

View File

@@ -8,7 +8,6 @@ use crate::encoding::{
errors::EncodingError,
evm::{
approvals::permit2::Permit2,
constants::{CALLBACK_CONSTRAINED_PROTOCOLS, IN_TRANSFER_REQUIRED_PROTOCOLS},
group_swaps::group_swaps,
strategy_encoder::{
strategy_validators::{SequentialSwapValidator, SplitSwapValidator, SwapValidator},
@@ -64,12 +63,13 @@ impl SingleSwapStrategyEncoder {
permit2,
selector,
swap_encoder_registry,
router_address,
router_address: router_address.clone(),
transfer_optimization: TransferOptimization::new(
chain.native_token()?,
chain.wrapped_token()?,
permit2_is_active,
token_in_already_in_router,
router_address,
),
})
}
@@ -246,7 +246,7 @@ impl SequentialSwapStrategyEncoder {
permit2,
selector,
swap_encoder_registry,
router_address,
router_address: router_address.clone(),
native_address: chain.native_token()?,
wrapped_address: chain.wrapped_token()?,
sequential_swap_validator: SequentialSwapValidator,
@@ -255,6 +255,7 @@ impl SequentialSwapStrategyEncoder {
chain.wrapped_token()?,
permit2_is_active,
token_in_already_in_router,
router_address,
),
})
}
@@ -295,7 +296,7 @@ impl StrategyEncoder for SequentialSwapStrategyEncoder {
}
let mut swaps = vec![];
let mut next_in_between_swap_optimization = true;
let mut next_in_between_swap_optimization_allowed = true;
for (i, grouped_swap) in grouped_swaps.iter().enumerate() {
let protocol = grouped_swap.protocol_system.clone();
let swap_encoder = self
@@ -307,37 +308,19 @@ impl StrategyEncoder for SequentialSwapStrategyEncoder {
))
})?;
let in_between_swap_optimization = next_in_between_swap_optimization;
let in_between_swap_optimization_allowed = next_in_between_swap_optimization_allowed;
let next_swap = grouped_swaps.get(i + 1);
// if there is a next swap
let swap_receiver = if let Some(next) = next_swap {
// if the protocol of the next swap supports transfer in optimization
if IN_TRANSFER_REQUIRED_PROTOCOLS.contains(&next.protocol_system.as_str()) {
// if the protocol does not allow for chained swaps, we can't optimize the
// receiver of this swap nor the transfer in of the next swap
if CALLBACK_CONSTRAINED_PROTOCOLS.contains(&next.protocol_system.as_str()) {
next_in_between_swap_optimization = false;
self.router_address.clone()
} else {
Bytes::from_str(&next.swaps[0].component.id.clone()).map_err(|_| {
EncodingError::FatalError("Invalid component id".to_string())
})?
}
} else {
// the protocol of the next swap does not support transfer in optimization
self.router_address.clone()
}
} else {
solution.receiver.clone() // last swap - there is not next swap
};
let (swap_receiver, next_swap_optimization) = self
.transfer_optimization
.get_receiver(solution.receiver.clone(), next_swap)?;
next_in_between_swap_optimization_allowed = next_swap_optimization;
let transfer_type = self
.transfer_optimization
.get_transfer_type(
grouped_swap.clone(),
solution.given_token.clone(),
wrap,
in_between_swap_optimization,
in_between_swap_optimization_allowed,
);
let encoding_context = EncodingContext {
receiver: swap_receiver.clone(),
@@ -463,12 +446,13 @@ impl SplitSwapStrategyEncoder {
native_address: chain.native_token()?,
wrapped_address: chain.wrapped_token()?,
split_swap_validator: SplitSwapValidator,
router_address,
router_address: router_address.clone(),
transfer_optimization: TransferOptimization::new(
chain.native_token()?,
chain.wrapped_token()?,
permit2_is_active,
token_in_already_in_router,
router_address,
),
})
}

View File

@@ -1,7 +1,13 @@
use std::str::FromStr;
use tycho_common::Bytes;
use crate::encoding::{
evm::{constants::IN_TRANSFER_REQUIRED_PROTOCOLS, group_swaps::SwapGroup},
errors::EncodingError,
evm::{
constants::{CALLBACK_CONSTRAINED_PROTOCOLS, IN_TRANSFER_REQUIRED_PROTOCOLS},
group_swaps::SwapGroup,
},
models::TransferType,
};
@@ -12,6 +18,7 @@ pub struct TransferOptimization {
wrapped_token: Bytes,
permit2: bool,
token_in_already_in_router: bool,
router_address: Bytes,
}
impl TransferOptimization {
@@ -20,9 +27,17 @@ impl TransferOptimization {
wrapped_token: Bytes,
permit2: bool,
token_in_already_in_router: bool,
router_address: Bytes,
) -> Self {
TransferOptimization { native_token, wrapped_token, permit2, token_in_already_in_router }
TransferOptimization {
native_token,
wrapped_token,
permit2,
token_in_already_in_router,
router_address,
}
}
/// Returns the transfer method that should be used for the given swap and solution.
pub fn get_transfer_type(
&self,
@@ -75,13 +90,50 @@ impl TransferOptimization {
TransferType::TransferToProtocol
}
}
// Returns the optimized receiver of the swap. This is used to chain swaps together and avoid
// unnecessary token transfers.
// Returns the receiver address and a boolean indicating whether the receiver is optimized (this
// is necessary for the next swap transfer type decision).
pub fn get_receiver(
&self,
solution_receiver: Bytes,
next_swap: Option<&SwapGroup>,
) -> Result<(Bytes, bool), EncodingError> {
if let Some(next) = next_swap {
// if the protocol of the next swap supports transfer in optimization
if IN_TRANSFER_REQUIRED_PROTOCOLS.contains(&next.protocol_system.as_str()) {
// if the protocol does not allow for chained swaps, we can't optimize the
// receiver of this swap nor the transfer in of the next swap
if CALLBACK_CONSTRAINED_PROTOCOLS.contains(&next.protocol_system.as_str()) {
Ok((self.router_address.clone(), false))
} else {
Ok((
Bytes::from_str(&next.swaps[0].component.id.clone()).map_err(|_| {
EncodingError::FatalError("Invalid component id".to_string())
})?,
true,
))
}
} else {
// the protocol of the next swap does not support transfer in optimization
Ok((self.router_address.clone(), false))
}
} else {
// last swap - there is no next swap
Ok((solution_receiver, false))
}
}
}
#[cfg(test)]
mod tests {
use alloy_primitives::hex;
use rstest::rstest;
use tycho_common::models::protocol::ProtocolComponent;
use super::*;
use crate::encoding::models::Swap;
fn weth() -> Bytes {
Bytes::from(hex!("c02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").to_vec())
@@ -99,6 +151,10 @@ mod tests {
Bytes::from(hex!("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").to_vec())
}
fn router_address() -> Bytes {
Bytes::from("0x5615deb798bb3e4dfa0139dfa1b3d433cc23b72f")
}
#[test]
fn test_first_swap_transfer_from_permit2() {
// The swap token is the same as the given token, which is not the native token
@@ -109,7 +165,7 @@ mod tests {
split: 0f64,
swaps: vec![],
};
let optimization = TransferOptimization::new(eth(), weth(), true, false);
let optimization = TransferOptimization::new(eth(), weth(), true, false, router_address());
let transfer_method = optimization.get_transfer_type(swap.clone(), weth(), false, false);
assert_eq!(transfer_method, TransferType::TransferPermit2ToProtocol);
}
@@ -124,7 +180,7 @@ mod tests {
split: 0f64,
swaps: vec![],
};
let optimization = TransferOptimization::new(eth(), weth(), false, false);
let optimization = TransferOptimization::new(eth(), weth(), false, false, router_address());
let transfer_method = optimization.get_transfer_type(swap.clone(), weth(), false, false);
assert_eq!(transfer_method, TransferType::TransferFromToProtocol);
}
@@ -140,7 +196,7 @@ mod tests {
split: 0f64,
swaps: vec![],
};
let optimization = TransferOptimization::new(eth(), weth(), false, false);
let optimization = TransferOptimization::new(eth(), weth(), false, false, router_address());
let transfer_method = optimization.get_transfer_type(swap.clone(), eth(), false, false);
assert_eq!(transfer_method, TransferType::None);
}
@@ -156,7 +212,7 @@ mod tests {
split: 0f64,
swaps: vec![],
};
let optimization = TransferOptimization::new(eth(), weth(), false, false);
let optimization = TransferOptimization::new(eth(), weth(), false, false, router_address());
let transfer_method = optimization.get_transfer_type(swap.clone(), eth(), true, false);
assert_eq!(transfer_method, TransferType::TransferToProtocol);
}
@@ -172,7 +228,7 @@ mod tests {
split: 0f64,
swaps: vec![],
};
let optimization = TransferOptimization::new(eth(), weth(), false, false);
let optimization = TransferOptimization::new(eth(), weth(), false, false, router_address());
let transfer_method = optimization.get_transfer_type(swap.clone(), weth(), false, false);
assert_eq!(transfer_method, TransferType::TransferToProtocol);
}
@@ -188,7 +244,7 @@ mod tests {
split: 0f64,
swaps: vec![],
};
let optimization = TransferOptimization::new(eth(), weth(), false, false);
let optimization = TransferOptimization::new(eth(), weth(), false, false, router_address());
let transfer_method = optimization.get_transfer_type(swap.clone(), weth(), false, false);
assert_eq!(transfer_method, TransferType::None);
}
@@ -204,7 +260,7 @@ mod tests {
split: 0f64,
swaps: vec![],
};
let optimization = TransferOptimization::new(eth(), weth(), false, false);
let optimization = TransferOptimization::new(eth(), weth(), false, false, router_address());
let transfer_method = optimization.get_transfer_type(swap.clone(), weth(), false, true);
assert_eq!(transfer_method, TransferType::None);
}
@@ -220,7 +276,7 @@ mod tests {
split: 0f64,
swaps: vec![],
};
let optimization = TransferOptimization::new(eth(), weth(), false, true);
let optimization = TransferOptimization::new(eth(), weth(), false, true, router_address());
let transfer_method = optimization.get_transfer_type(swap.clone(), usdc(), false, false);
assert_eq!(transfer_method, TransferType::TransferToProtocol);
}
@@ -236,8 +292,61 @@ mod tests {
split: 0f64,
swaps: vec![],
};
let optimization = TransferOptimization::new(eth(), weth(), false, true);
let optimization = TransferOptimization::new(eth(), weth(), false, true, router_address());
let transfer_method = optimization.get_transfer_type(swap.clone(), usdc(), false, false);
assert_eq!(transfer_method, TransferType::None);
}
fn receiver() -> Bytes {
Bytes::from("0xcd09f75E2BF2A4d11F3AB23f1389FcC1621c0cc2")
}
fn component_id() -> Bytes {
Bytes::from("0xA478c2975Ab1Ea89e8196811F51A7B7Ade33eB11")
}
#[rstest]
// there is no next swap -> receiver is the solution receiver
#[case(None, receiver(), false)]
// protocol of next swap supports transfer in optimization
#[case(Some("uniswap_v2"), component_id(), true)]
// protocol of next swap supports transfer in optimization but is callback constrained
#[case(Some("uniswap_v3"), router_address(), false)]
// protocol of next swap does not support transfer in optimization
#[case(Some("vm:curve"), router_address(), false)]
fn test_get_receiver(
#[case] protocol: Option<&str>,
#[case] expected_receiver: Bytes,
#[case] expected_optimization: bool,
) {
let optimization = TransferOptimization::new(eth(), weth(), false, false, router_address());
let next_swap = if protocol.is_none() {
None
} else {
Some(SwapGroup {
protocol_system: protocol.unwrap().to_string(),
token_in: usdc(),
token_out: dai(),
split: 0f64,
swaps: vec![Swap {
component: ProtocolComponent {
protocol_system: protocol.unwrap().to_string(),
id: component_id().to_string(),
..Default::default()
},
token_in: usdc(),
token_out: dai(),
split: 0f64,
}],
})
};
let result = optimization.get_receiver(receiver(), next_swap.as_ref());
assert!(result.is_ok());
let (actual_receiver, optimization_flag) = result.unwrap();
assert_eq!(actual_receiver, expected_receiver);
assert_eq!(optimization_flag, expected_optimization);
}
}