chore: Move get_receiver logic inside TransferOptimization

Added tests for all cases

--- don't change below this line ---
ENG-4446 Took 1 hour 8 minutes
This commit is contained in:
Diana Carvalho
2025-04-22 15:30:33 +01:00
parent cebacc68fe
commit dff4a345fc
2 changed files with 133 additions and 40 deletions

View File

@@ -8,7 +8,6 @@ use crate::encoding::{
errors::EncodingError, errors::EncodingError,
evm::{ evm::{
approvals::permit2::Permit2, approvals::permit2::Permit2,
constants::{CALLBACK_CONSTRAINED_PROTOCOLS, IN_TRANSFER_REQUIRED_PROTOCOLS},
group_swaps::group_swaps, group_swaps::group_swaps,
strategy_encoder::{ strategy_encoder::{
strategy_validators::{SequentialSwapValidator, SplitSwapValidator, SwapValidator}, strategy_validators::{SequentialSwapValidator, SplitSwapValidator, SwapValidator},
@@ -64,12 +63,13 @@ impl SingleSwapStrategyEncoder {
permit2, permit2,
selector, selector,
swap_encoder_registry, swap_encoder_registry,
router_address, router_address: router_address.clone(),
transfer_optimization: TransferOptimization::new( transfer_optimization: TransferOptimization::new(
chain.native_token()?, chain.native_token()?,
chain.wrapped_token()?, chain.wrapped_token()?,
permit2_is_active, permit2_is_active,
token_in_already_in_router, token_in_already_in_router,
router_address,
), ),
}) })
} }
@@ -246,7 +246,7 @@ impl SequentialSwapStrategyEncoder {
permit2, permit2,
selector, selector,
swap_encoder_registry, swap_encoder_registry,
router_address, router_address: router_address.clone(),
native_address: chain.native_token()?, native_address: chain.native_token()?,
wrapped_address: chain.wrapped_token()?, wrapped_address: chain.wrapped_token()?,
sequential_swap_validator: SequentialSwapValidator, sequential_swap_validator: SequentialSwapValidator,
@@ -255,6 +255,7 @@ impl SequentialSwapStrategyEncoder {
chain.wrapped_token()?, chain.wrapped_token()?,
permit2_is_active, permit2_is_active,
token_in_already_in_router, token_in_already_in_router,
router_address,
), ),
}) })
} }
@@ -295,7 +296,7 @@ impl StrategyEncoder for SequentialSwapStrategyEncoder {
} }
let mut swaps = vec![]; let mut swaps = vec![];
let mut next_in_between_swap_optimization = true; let mut next_in_between_swap_optimization_allowed = true;
for (i, grouped_swap) in grouped_swaps.iter().enumerate() { for (i, grouped_swap) in grouped_swaps.iter().enumerate() {
let protocol = grouped_swap.protocol_system.clone(); let protocol = grouped_swap.protocol_system.clone();
let swap_encoder = self let swap_encoder = self
@@ -307,37 +308,19 @@ impl StrategyEncoder for SequentialSwapStrategyEncoder {
)) ))
})?; })?;
let in_between_swap_optimization = next_in_between_swap_optimization; let in_between_swap_optimization_allowed = next_in_between_swap_optimization_allowed;
let next_swap = grouped_swaps.get(i + 1); let next_swap = grouped_swaps.get(i + 1);
// if there is a next swap let (swap_receiver, next_swap_optimization) = self
let swap_receiver = if let Some(next) = next_swap { .transfer_optimization
// if the protocol of the next swap supports transfer in optimization .get_receiver(solution.receiver.clone(), next_swap)?;
if IN_TRANSFER_REQUIRED_PROTOCOLS.contains(&next.protocol_system.as_str()) { next_in_between_swap_optimization_allowed = next_swap_optimization;
// if the protocol does not allow for chained swaps, we can't optimize the
// receiver of this swap nor the transfer in of the next swap
if CALLBACK_CONSTRAINED_PROTOCOLS.contains(&next.protocol_system.as_str()) {
next_in_between_swap_optimization = false;
self.router_address.clone()
} else {
Bytes::from_str(&next.swaps[0].component.id.clone()).map_err(|_| {
EncodingError::FatalError("Invalid component id".to_string())
})?
}
} else {
// the protocol of the next swap does not support transfer in optimization
self.router_address.clone()
}
} else {
solution.receiver.clone() // last swap - there is not next swap
};
let transfer_type = self let transfer_type = self
.transfer_optimization .transfer_optimization
.get_transfer_type( .get_transfer_type(
grouped_swap.clone(), grouped_swap.clone(),
solution.given_token.clone(), solution.given_token.clone(),
wrap, wrap,
in_between_swap_optimization, in_between_swap_optimization_allowed,
); );
let encoding_context = EncodingContext { let encoding_context = EncodingContext {
receiver: swap_receiver.clone(), receiver: swap_receiver.clone(),
@@ -463,12 +446,13 @@ impl SplitSwapStrategyEncoder {
native_address: chain.native_token()?, native_address: chain.native_token()?,
wrapped_address: chain.wrapped_token()?, wrapped_address: chain.wrapped_token()?,
split_swap_validator: SplitSwapValidator, split_swap_validator: SplitSwapValidator,
router_address, router_address: router_address.clone(),
transfer_optimization: TransferOptimization::new( transfer_optimization: TransferOptimization::new(
chain.native_token()?, chain.native_token()?,
chain.wrapped_token()?, chain.wrapped_token()?,
permit2_is_active, permit2_is_active,
token_in_already_in_router, token_in_already_in_router,
router_address,
), ),
}) })
} }

View File

@@ -1,7 +1,13 @@
use std::str::FromStr;
use tycho_common::Bytes; use tycho_common::Bytes;
use crate::encoding::{ use crate::encoding::{
evm::{constants::IN_TRANSFER_REQUIRED_PROTOCOLS, group_swaps::SwapGroup}, errors::EncodingError,
evm::{
constants::{CALLBACK_CONSTRAINED_PROTOCOLS, IN_TRANSFER_REQUIRED_PROTOCOLS},
group_swaps::SwapGroup,
},
models::TransferType, models::TransferType,
}; };
@@ -12,6 +18,7 @@ pub struct TransferOptimization {
wrapped_token: Bytes, wrapped_token: Bytes,
permit2: bool, permit2: bool,
token_in_already_in_router: bool, token_in_already_in_router: bool,
router_address: Bytes,
} }
impl TransferOptimization { impl TransferOptimization {
@@ -20,9 +27,17 @@ impl TransferOptimization {
wrapped_token: Bytes, wrapped_token: Bytes,
permit2: bool, permit2: bool,
token_in_already_in_router: bool, token_in_already_in_router: bool,
router_address: Bytes,
) -> Self { ) -> Self {
TransferOptimization { native_token, wrapped_token, permit2, token_in_already_in_router } TransferOptimization {
native_token,
wrapped_token,
permit2,
token_in_already_in_router,
router_address,
}
} }
/// Returns the transfer method that should be used for the given swap and solution. /// Returns the transfer method that should be used for the given swap and solution.
pub fn get_transfer_type( pub fn get_transfer_type(
&self, &self,
@@ -75,13 +90,50 @@ impl TransferOptimization {
TransferType::TransferToProtocol TransferType::TransferToProtocol
} }
} }
// Returns the optimized receiver of the swap. This is used to chain swaps together and avoid
// unnecessary token transfers.
// Returns the receiver address and a boolean indicating whether the receiver is optimized (this
// is necessary for the next swap transfer type decision).
pub fn get_receiver(
&self,
solution_receiver: Bytes,
next_swap: Option<&SwapGroup>,
) -> Result<(Bytes, bool), EncodingError> {
if let Some(next) = next_swap {
// if the protocol of the next swap supports transfer in optimization
if IN_TRANSFER_REQUIRED_PROTOCOLS.contains(&next.protocol_system.as_str()) {
// if the protocol does not allow for chained swaps, we can't optimize the
// receiver of this swap nor the transfer in of the next swap
if CALLBACK_CONSTRAINED_PROTOCOLS.contains(&next.protocol_system.as_str()) {
Ok((self.router_address.clone(), false))
} else {
Ok((
Bytes::from_str(&next.swaps[0].component.id.clone()).map_err(|_| {
EncodingError::FatalError("Invalid component id".to_string())
})?,
true,
))
}
} else {
// the protocol of the next swap does not support transfer in optimization
Ok((self.router_address.clone(), false))
}
} else {
// last swap - there is no next swap
Ok((solution_receiver, false))
}
}
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use alloy_primitives::hex; use alloy_primitives::hex;
use rstest::rstest;
use tycho_common::models::protocol::ProtocolComponent;
use super::*; use super::*;
use crate::encoding::models::Swap;
fn weth() -> Bytes { fn weth() -> Bytes {
Bytes::from(hex!("c02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").to_vec()) Bytes::from(hex!("c02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").to_vec())
@@ -99,6 +151,10 @@ mod tests {
Bytes::from(hex!("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").to_vec()) Bytes::from(hex!("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").to_vec())
} }
fn router_address() -> Bytes {
Bytes::from("0x5615deb798bb3e4dfa0139dfa1b3d433cc23b72f")
}
#[test] #[test]
fn test_first_swap_transfer_from_permit2() { fn test_first_swap_transfer_from_permit2() {
// The swap token is the same as the given token, which is not the native token // The swap token is the same as the given token, which is not the native token
@@ -109,7 +165,7 @@ mod tests {
split: 0f64, split: 0f64,
swaps: vec![], swaps: vec![],
}; };
let optimization = TransferOptimization::new(eth(), weth(), true, false); let optimization = TransferOptimization::new(eth(), weth(), true, false, router_address());
let transfer_method = optimization.get_transfer_type(swap.clone(), weth(), false, false); let transfer_method = optimization.get_transfer_type(swap.clone(), weth(), false, false);
assert_eq!(transfer_method, TransferType::TransferPermit2ToProtocol); assert_eq!(transfer_method, TransferType::TransferPermit2ToProtocol);
} }
@@ -124,7 +180,7 @@ mod tests {
split: 0f64, split: 0f64,
swaps: vec![], swaps: vec![],
}; };
let optimization = TransferOptimization::new(eth(), weth(), false, false); let optimization = TransferOptimization::new(eth(), weth(), false, false, router_address());
let transfer_method = optimization.get_transfer_type(swap.clone(), weth(), false, false); let transfer_method = optimization.get_transfer_type(swap.clone(), weth(), false, false);
assert_eq!(transfer_method, TransferType::TransferFromToProtocol); assert_eq!(transfer_method, TransferType::TransferFromToProtocol);
} }
@@ -140,7 +196,7 @@ mod tests {
split: 0f64, split: 0f64,
swaps: vec![], swaps: vec![],
}; };
let optimization = TransferOptimization::new(eth(), weth(), false, false); let optimization = TransferOptimization::new(eth(), weth(), false, false, router_address());
let transfer_method = optimization.get_transfer_type(swap.clone(), eth(), false, false); let transfer_method = optimization.get_transfer_type(swap.clone(), eth(), false, false);
assert_eq!(transfer_method, TransferType::None); assert_eq!(transfer_method, TransferType::None);
} }
@@ -156,7 +212,7 @@ mod tests {
split: 0f64, split: 0f64,
swaps: vec![], swaps: vec![],
}; };
let optimization = TransferOptimization::new(eth(), weth(), false, false); let optimization = TransferOptimization::new(eth(), weth(), false, false, router_address());
let transfer_method = optimization.get_transfer_type(swap.clone(), eth(), true, false); let transfer_method = optimization.get_transfer_type(swap.clone(), eth(), true, false);
assert_eq!(transfer_method, TransferType::TransferToProtocol); assert_eq!(transfer_method, TransferType::TransferToProtocol);
} }
@@ -172,7 +228,7 @@ mod tests {
split: 0f64, split: 0f64,
swaps: vec![], swaps: vec![],
}; };
let optimization = TransferOptimization::new(eth(), weth(), false, false); let optimization = TransferOptimization::new(eth(), weth(), false, false, router_address());
let transfer_method = optimization.get_transfer_type(swap.clone(), weth(), false, false); let transfer_method = optimization.get_transfer_type(swap.clone(), weth(), false, false);
assert_eq!(transfer_method, TransferType::TransferToProtocol); assert_eq!(transfer_method, TransferType::TransferToProtocol);
} }
@@ -188,7 +244,7 @@ mod tests {
split: 0f64, split: 0f64,
swaps: vec![], swaps: vec![],
}; };
let optimization = TransferOptimization::new(eth(), weth(), false, false); let optimization = TransferOptimization::new(eth(), weth(), false, false, router_address());
let transfer_method = optimization.get_transfer_type(swap.clone(), weth(), false, false); let transfer_method = optimization.get_transfer_type(swap.clone(), weth(), false, false);
assert_eq!(transfer_method, TransferType::None); assert_eq!(transfer_method, TransferType::None);
} }
@@ -204,7 +260,7 @@ mod tests {
split: 0f64, split: 0f64,
swaps: vec![], swaps: vec![],
}; };
let optimization = TransferOptimization::new(eth(), weth(), false, false); let optimization = TransferOptimization::new(eth(), weth(), false, false, router_address());
let transfer_method = optimization.get_transfer_type(swap.clone(), weth(), false, true); let transfer_method = optimization.get_transfer_type(swap.clone(), weth(), false, true);
assert_eq!(transfer_method, TransferType::None); assert_eq!(transfer_method, TransferType::None);
} }
@@ -220,7 +276,7 @@ mod tests {
split: 0f64, split: 0f64,
swaps: vec![], swaps: vec![],
}; };
let optimization = TransferOptimization::new(eth(), weth(), false, true); let optimization = TransferOptimization::new(eth(), weth(), false, true, router_address());
let transfer_method = optimization.get_transfer_type(swap.clone(), usdc(), false, false); let transfer_method = optimization.get_transfer_type(swap.clone(), usdc(), false, false);
assert_eq!(transfer_method, TransferType::TransferToProtocol); assert_eq!(transfer_method, TransferType::TransferToProtocol);
} }
@@ -236,8 +292,61 @@ mod tests {
split: 0f64, split: 0f64,
swaps: vec![], swaps: vec![],
}; };
let optimization = TransferOptimization::new(eth(), weth(), false, true); let optimization = TransferOptimization::new(eth(), weth(), false, true, router_address());
let transfer_method = optimization.get_transfer_type(swap.clone(), usdc(), false, false); let transfer_method = optimization.get_transfer_type(swap.clone(), usdc(), false, false);
assert_eq!(transfer_method, TransferType::None); assert_eq!(transfer_method, TransferType::None);
} }
fn receiver() -> Bytes {
Bytes::from("0xcd09f75E2BF2A4d11F3AB23f1389FcC1621c0cc2")
}
fn component_id() -> Bytes {
Bytes::from("0xA478c2975Ab1Ea89e8196811F51A7B7Ade33eB11")
}
#[rstest]
// there is no next swap -> receiver is the solution receiver
#[case(None, receiver(), false)]
// protocol of next swap supports transfer in optimization
#[case(Some("uniswap_v2"), component_id(), true)]
// protocol of next swap supports transfer in optimization but is callback constrained
#[case(Some("uniswap_v3"), router_address(), false)]
// protocol of next swap does not support transfer in optimization
#[case(Some("vm:curve"), router_address(), false)]
fn test_get_receiver(
#[case] protocol: Option<&str>,
#[case] expected_receiver: Bytes,
#[case] expected_optimization: bool,
) {
let optimization = TransferOptimization::new(eth(), weth(), false, false, router_address());
let next_swap = if protocol.is_none() {
None
} else {
Some(SwapGroup {
protocol_system: protocol.unwrap().to_string(),
token_in: usdc(),
token_out: dai(),
split: 0f64,
swaps: vec![Swap {
component: ProtocolComponent {
protocol_system: protocol.unwrap().to_string(),
id: component_id().to_string(),
..Default::default()
},
token_in: usdc(),
token_out: dai(),
split: 0f64,
}],
})
};
let result = optimization.get_receiver(receiver(), next_swap.as_ref());
assert!(result.is_ok());
let (actual_receiver, optimization_flag) = result.unwrap();
assert_eq!(actual_receiver, expected_receiver);
assert_eq!(optimization_flag, expected_optimization);
}
} }