diff --git a/protocol-testing/Cargo.lock b/protocol-testing/Cargo.lock index a6cb008..d62b3f8 100644 --- a/protocol-testing/Cargo.lock +++ b/protocol-testing/Cargo.lock @@ -7823,7 +7823,8 @@ dependencies = [ [[package]] name = "tycho-simulation" -version = "0.155.2" +version = "0.156.0" +source = "git+https://github.com/propeller-heads/tycho-simulation.git?tag=0.156.0#1983a787440e8ae757626d808a6e619baffc52f2" dependencies = [ "alloy", "async-stream", diff --git a/protocol-testing/src/rpc.rs b/protocol-testing/src/rpc.rs index 7ad7a33..7a87cc0 100644 --- a/protocol-testing/src/rpc.rs +++ b/protocol-testing/src/rpc.rs @@ -69,13 +69,17 @@ impl RPCProvider { } } - // TODO: Implement - // async fn get_block_header(&self, _block_number: u64) { - // let provider = ProviderBuilder::new().on_http(self.url); - // let block_id: BlockId = BlockId::from(block_number); - // - // let block = provider.get_block(block_id) - // } + pub async fn get_block_header(&self, block_number: u64) -> miette::Result { + let provider = ProviderBuilder::new().connect_http(self.url.clone()); + let block_id: BlockId = BlockId::from(block_number); + + provider + .get_block(block_id) + .await + .into_diagnostic() + .wrap_err("Failed to fetch block header") + .and_then(|block_opt| block_opt.ok_or_else(|| miette::miette!("Block not found"))) + } } #[cfg(test)] @@ -122,4 +126,23 @@ mod tests { assert_eq!(balance, U256::from(717250938432_u64)); } + + #[tokio::test] + async fn get_block_header_test() { + let eth_rpc_url = env::var("RPC_URL").expect("Missing RPC_URL in environment"); + + let rpc_provider = RPCProvider::new(eth_rpc_url); + let block_number = 21998530; + + let block_header = rpc_provider + .get_block_header(block_number) + .await + .unwrap(); + + // Verify that we got a block with the correct number + assert_eq!(block_number, block_header.header.number); + + // Verify that the timestamp is non-zero + assert!(block_header.header.timestamp > 0); + } } diff --git a/protocol-testing/src/test_runner.rs b/protocol-testing/src/test_runner.rs index dfc3ca8..d7ed12a 100644 --- a/protocol-testing/src/test_runner.rs +++ b/protocol-testing/src/test_runner.rs @@ -336,22 +336,53 @@ fn validate_state( decoder_context, ); + // Filter out components that have skip_simulation = true (match Python behavior) + let simulation_component_ids: std::collections::HashSet = expected_components + .iter() + .filter(|c| !c.skip_simulation) + .map(|c| c.base.id.clone()) + .collect(); + + info!("Components to simulate: {}", simulation_component_ids.len()); + for id in &simulation_component_ids { + info!(" Simulating component: {}", id); + } + + if simulation_component_ids.is_empty() { + info!("No components to simulate, skipping simulation validation"); + return Ok(()); + } + // Mock a stream message, with only a Snapshot and no deltas let mut states: HashMap = HashMap::new(); - for (id, component) in components_by_id { - let component_id = &id.clone(); + for (id, component) in &components_by_id { + let component_id = id; + + // Only include components that should be simulated + if !simulation_component_ids.contains(component_id) { + continue; + } + let state = protocol_states_by_id .get(component_id) - .wrap_err("Failed to get state for component")? + .wrap_err("Failed to get state for component" + )? .clone(); - let component_with_state = - ComponentWithState { state, component, component_tvl: None, entrypoints: vec![] }; // TODO + + let component_with_state = ComponentWithState { + state, + component: component.clone(), + component_tvl: None, + entrypoints: vec![], + }; // TODO states.insert(component_id.clone(), component_with_state); } + // Convert vm_storages to a HashMap - match Python behavior exactly let vm_storage: HashMap = vm_storages .into_iter() .map(|x| (x.address.clone(), x)) .collect(); + let snapshot = Snapshot { states, vm_storage }; let bytes = [0u8; 32]; @@ -414,6 +445,8 @@ fn validate_state( // We then retrieve the amount out for 0.1%, 1% and 10%. let percentages = [0.001, 0.01, 0.1]; // Get limits for this token pair + // TODO do this again, but reverse the order of the tokens to get the opposite swap + // direction let (max_input, max_output) = state .get_limits(tokens[0].address.clone(), tokens[1].address.clone()) .into_diagnostic()