diff --git a/protocols/kad/Cargo.toml b/protocols/kad/Cargo.toml index 5c8e03d4..88730dd6 100644 --- a/protocols/kad/Cargo.toml +++ b/protocols/kad/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "libp2p-kad" description = "Kademlia protocol for libp2p" -version = "0.2.0" +version = "0.2.1" authors = ["Parity Technologies "] license = "MIT" repository = "https://github.com/libp2p/rust-libp2p" diff --git a/protocols/kad/src/kbucket.rs b/protocols/kad/src/kbucket.rs index 191eac2b..e70dffda 100644 --- a/protocols/kad/src/kbucket.rs +++ b/protocols/kad/src/kbucket.rs @@ -109,7 +109,7 @@ impl KBucketsPeerId for Multihash { let my_hash = U512::from(self.digest()); let other_hash = U512::from(other.digest()); let xor = my_hash ^ other_hash; - xor.leading_zeros() + 512 - xor.leading_zeros() } #[inline] @@ -142,7 +142,7 @@ where // Returns `None` if out of range, which happens if `id` is the same as the local peer id. #[inline] fn bucket_num(&self, id: &Id) -> Option { - (Id::max_distance() - 1).checked_sub(self.my_id.distance_with(id) as usize) + (self.my_id.distance_with(id) as usize).checked_sub(1) } /// Returns an iterator to all the buckets of this table. @@ -323,32 +323,15 @@ impl<'a, Id: 'a, Val: 'a> Bucket<'a, Id, Val> { mod tests { extern crate rand; use self::rand::random; - use kbucket::{KBucketsTable, UpdateOutcome, MAX_NODES_PER_BUCKET}; - use multihash::Multihash; + use kbucket::{KBucketsPeerId, KBucketsTable, UpdateOutcome, MAX_NODES_PER_BUCKET}; + use multihash::{Multihash, Hash}; use std::thread; use std::time::Duration; #[test] fn basic_closest() { - let my_id = { - let mut bytes = vec![random(); 34]; - bytes[0] = 18; - bytes[1] = 32; - Multihash::from_bytes(bytes.clone()).expect(&format!( - "creating `my_id` Multihash from bytes {:#?} failed", - bytes - )) - }; - - let other_id = { - let mut bytes = vec![random(); 34]; - bytes[0] = 18; - bytes[1] = 32; - Multihash::from_bytes(bytes.clone()).expect(&format!( - "creating `other_id` Multihash from bytes {:#?} failed", - bytes - )) - }; + let my_id = Multihash::random(Hash::SHA2256); + let other_id = Multihash::random(Hash::SHA2256); let mut table = KBucketsTable::new(my_id, Duration::from_secs(5)); let _ = table.update(other_id.clone(), ()); @@ -360,12 +343,7 @@ mod tests { #[test] fn update_local_id_fails() { - let my_id = { - let mut bytes = vec![random(); 34]; - bytes[0] = 18; - bytes[1] = 32; - Multihash::from_bytes(bytes).unwrap() - }; + let my_id = Multihash::random(Hash::SHA2256); let mut table = KBucketsTable::new(my_id.clone(), Duration::from_secs(5)); match table.update(my_id, ()) { @@ -376,12 +354,7 @@ mod tests { #[test] fn update_time_last_refresh() { - let my_id = { - let mut bytes = vec![random(); 34]; - bytes[0] = 18; - bytes[1] = 32; - Multihash::from_bytes(bytes).unwrap() - }; + let my_id = Multihash::random(Hash::SHA2256); // Generate some other IDs varying by just one bit. let other_ids = (0..random::() % 20) @@ -414,12 +387,7 @@ mod tests { #[test] fn full_kbucket() { - let my_id = { - let mut bytes = vec![random(); 34]; - bytes[0] = 18; - bytes[1] = 32; - Multihash::from_bytes(bytes).unwrap() - }; + let my_id = Multihash::random(Hash::SHA2256); assert!(MAX_NODES_PER_BUCKET <= 251); // Test doesn't work otherwise. let mut fill_ids = (0..MAX_NODES_PER_BUCKET + 3) @@ -468,4 +436,18 @@ mod tests { UpdateOutcome::NeedPing(second_node) ); } + + #[test] + fn self_distance_zero() { + let a = Multihash::random(Hash::SHA2256); + assert_eq!(a.distance_with(&a), 0); + } + + #[test] + fn distance_correct_order() { + let a = Multihash::random(Hash::SHA2256); + let b = Multihash::random(Hash::SHA2256); + assert!(a.distance_with(&a) < b.distance_with(&a)); + assert!(a.distance_with(&b) > b.distance_with(&b)); + } } diff --git a/protocols/kad/src/query.rs b/protocols/kad/src/query.rs index 71d8f60f..50d4f8c4 100644 --- a/protocols/kad/src/query.rs +++ b/protocols/kad/src/query.rs @@ -106,17 +106,22 @@ impl QueryState { /// /// You should call `poll()` this function returns in order to know what to do. pub fn new(config: QueryConfig>) -> QueryState { + let mut closest_peers: SmallVec<[_; 32]> = config + .known_closest_peers + .into_iter() + .map(|peer_id| (peer_id, QueryPeerState::NotContacted)) + .take(config.num_results) + .collect(); + let target = config.target; + closest_peers.sort_by_key(|e| target.as_hash().distance_with(e.0.as_ref())); + closest_peers.dedup_by(|a, b| a.0 == b.0); + QueryState { - target: config.target, + target, stage: QueryStage::Iterating { no_closer_in_a_row: 0, }, - closest_peers: config - .known_closest_peers - .into_iter() - .map(|peer_id| (peer_id, QueryPeerState::NotContacted)) - .take(config.num_results) - .collect(), + closest_peers, parallelism: config.parallelism, num_results: config.num_results, rpc_timeout: config.rpc_timeout, @@ -160,28 +165,46 @@ impl QueryState { for elem_to_add in closer_peers { let target = &self.target; - let insert_pos = self.closest_peers.iter().position(|(id, _)| { - let a = target.as_hash().distance_with(id.as_ref()); - let b = target.as_hash().distance_with(elem_to_add.as_ref()); - a >= b + let elem_to_add_distance = target.as_hash().distance_with(elem_to_add.as_ref()); + let insert_pos_start = self.closest_peers.iter().position(|(id, _)| { + target.as_hash().distance_with(id.as_ref()) >= elem_to_add_distance }); - if let Some(insert_pos) = insert_pos { + if let Some(insert_pos_start) = insert_pos_start { + // We need to insert the element between `insert_pos_start` and + // `insert_pos_start + insert_pos_size`. + let insert_pos_size = self.closest_peers.iter() + .skip(insert_pos_start) + .position(|(id, _)| { + target.as_hash().distance_with(id.as_ref()) > elem_to_add_distance + }); + // Make sure we don't insert duplicates. - if self.closest_peers[insert_pos].0 != elem_to_add { - if insert_pos == 0 { + let duplicate = if let Some(insert_pos_size) = insert_pos_size { + self.closest_peers.iter().skip(insert_pos_start).take(insert_pos_size).any(|e| e.0 == elem_to_add) + } else { + self.closest_peers.iter().skip(insert_pos_start).any(|e| e.0 == elem_to_add) + }; + + if !duplicate { + if insert_pos_start == 0 { *no_closer_in_a_row = 0; } + debug_assert!(self.closest_peers.iter().all(|e| e.0 != elem_to_add)); self.closest_peers - .insert(insert_pos, (elem_to_add, QueryPeerState::NotContacted)); + .insert(insert_pos_start, (elem_to_add, QueryPeerState::NotContacted)); } } else if self.closest_peers.len() < self.num_results { + debug_assert!(self.closest_peers.iter().all(|e| e.0 != elem_to_add)); self.closest_peers .push((elem_to_add, QueryPeerState::NotContacted)); } } } + // Check for duplicates in `closest_peers`. + debug_assert!(self.closest_peers.windows(2).all(|w| w[0].0 != w[1].0)); + // Handle if `no_closer_in_a_row` is too high. let freeze = if let QueryStage::Iterating { no_closer_in_a_row } = self.stage { no_closer_in_a_row >= self.parallelism