From c13f03354b8ec179845df825afbddb6665decaba Mon Sep 17 00:00:00 2001 From: stuart nelson Date: Mon, 27 Sep 2021 18:22:10 +0200 Subject: [PATCH] protocols/kad: Check local store on get_providers (#2221) --- protocols/kad/CHANGELOG.md | 3 ++ protocols/kad/src/behaviour.rs | 9 +++++- protocols/kad/src/behaviour/test.rs | 48 +++++++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 1 deletion(-) diff --git a/protocols/kad/CHANGELOG.md b/protocols/kad/CHANGELOG.md index e344f9c2..d1557c34 100644 --- a/protocols/kad/CHANGELOG.md +++ b/protocols/kad/CHANGELOG.md @@ -8,7 +8,10 @@ - Introduce `KademliaStoreInserts` option, which allows to filter records (see [PR 2163]). +- Check local store when calling `Kademlia::get_providers` (see [PR 2221]). + [PR 2163]: https://github.com/libp2p/rust-libp2p/pull/2163 +[PR 2221]: https://github.com/libp2p/rust-libp2p/pull/2163 # 0.31.0 [2021-07-12] diff --git a/protocols/kad/src/behaviour.rs b/protocols/kad/src/behaviour.rs index 74b5616f..80236eb5 100644 --- a/protocols/kad/src/behaviour.rs +++ b/protocols/kad/src/behaviour.rs @@ -912,9 +912,16 @@ where /// The result of this operation is delivered in a /// reported via [`KademliaEvent::OutboundQueryCompleted{QueryResult::GetProviders}`]. pub fn get_providers(&mut self, key: record::Key) -> QueryId { + let providers = self + .store + .providers(&key) + .into_iter() + .filter(|p| !p.is_expired(Instant::now())) + .map(|p| p.provider) + .collect(); let info = QueryInfo::GetProviders { key: key.clone(), - providers: HashSet::new(), + providers, }; let target = kbucket::Key::new(key); let peers = self.kbuckets.closest_keys(&target); diff --git a/protocols/kad/src/behaviour/test.rs b/protocols/kad/src/behaviour/test.rs index fc855b04..a39ff5af 100644 --- a/protocols/kad/src/behaviour/test.rs +++ b/protocols/kad/src/behaviour/test.rs @@ -1317,3 +1317,51 @@ fn network_behaviour_inject_address_change() { kademlia.addresses_of_peer(&remote_peer_id), ); } + +#[test] +fn get_providers() { + fn prop(key: record::Key) { + let (_, mut single_swarm) = build_node(); + single_swarm + .behaviour_mut() + .start_providing(key.clone()) + .expect("could not provide"); + + block_on(async { + match single_swarm.next().await.unwrap() { + SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted { + result: QueryResult::StartProviding(Ok(_)), + .. + }) => {} + SwarmEvent::Behaviour(e) => panic!("Unexpected event: {:?}", e), + _ => {} + } + }); + + let query_id = single_swarm.behaviour_mut().get_providers(key.clone()); + + block_on(async { + match single_swarm.next().await.unwrap() { + SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted { + id, + result: + QueryResult::GetProviders(Ok(GetProvidersOk { + key: found_key, + providers, + .. + })), + .. + }) if id == query_id => { + assert_eq!(key, found_key); + assert_eq!( + single_swarm.local_peer_id(), + providers.iter().next().unwrap() + ); + } + SwarmEvent::Behaviour(e) => panic!("Unexpected event: {:?}", e), + _ => {} + } + }); + } + QuickCheck::new().tests(10).quickcheck(prop as fn(_)) +}