refactor(swarm): express dial logic linearly (#3253)

Previously, the logic within `Swarm::dial` involved fairly convoluted `match` expressions. This patch refactors this function to use new utility functions introduced on `DialOpts` to handle one concern at a time.

This has the advantage that we are covering slightly more cases now. Because we are parsing the `PeerId` only once at the top, checks like banning will now also act on dials that specify the `PeerId` as part of the `/p2p` protocol.
This commit is contained in:
Thomas Eizinger
2022-12-23 03:44:58 +11:00
committed by GitHub
parent 1765ae0395
commit d5f4acc6ed
2 changed files with 156 additions and 123 deletions

View File

@ -20,6 +20,8 @@
// DEALINGS IN THE SOFTWARE. // DEALINGS IN THE SOFTWARE.
use libp2p_core::connection::Endpoint; use libp2p_core::connection::Endpoint;
use libp2p_core::multiaddr::Protocol;
use libp2p_core::multihash::Multihash;
use libp2p_core::{Multiaddr, PeerId}; use libp2p_core::{Multiaddr, PeerId};
use std::num::NonZeroU8; use std::num::NonZeroU8;
@ -79,6 +81,104 @@ impl DialOpts {
DialOpts(Opts::WithoutPeerIdWithAddress(_)) => None, DialOpts(Opts::WithoutPeerIdWithAddress(_)) => None,
} }
} }
/// Retrieves the [`PeerId`] from the [`DialOpts`] if specified or otherwise tries to parse it
/// from the multihash in the `/p2p` part of the address, if present.
///
/// Note: A [`Multiaddr`] with something else other than a [`PeerId`] within the `/p2p` protocol is invalid as per specification.
/// Unfortunately, we are not making good use of the type system here.
/// Really, this function should be merged with [`DialOpts::get_peer_id`] above.
/// If it weren't for the parsing error, the function signatures would be the same.
///
/// See <https://github.com/multiformats/rust-multiaddr/issues/73>.
pub(crate) fn get_or_parse_peer_id(&self) -> Result<Option<PeerId>, Multihash> {
match self {
DialOpts(Opts::WithPeerId(WithPeerId { peer_id, .. })) => Ok(Some(*peer_id)),
DialOpts(Opts::WithPeerIdWithAddresses(WithPeerIdWithAddresses {
peer_id, ..
})) => Ok(Some(*peer_id)),
DialOpts(Opts::WithoutPeerIdWithAddress(WithoutPeerIdWithAddress {
address, ..
})) => {
let peer_id = address
.iter()
.last()
.and_then(|p| {
if let Protocol::P2p(ma) = p {
Some(PeerId::try_from(ma))
} else {
None
}
})
.transpose()?;
Ok(peer_id)
}
}
}
pub(crate) fn get_addresses(&self) -> Vec<Multiaddr> {
match self {
DialOpts(Opts::WithPeerId(WithPeerId { .. })) => vec![],
DialOpts(Opts::WithPeerIdWithAddresses(WithPeerIdWithAddresses {
addresses, ..
})) => addresses.clone(),
DialOpts(Opts::WithoutPeerIdWithAddress(WithoutPeerIdWithAddress {
address, ..
})) => vec![address.clone()],
}
}
pub(crate) fn extend_addresses_through_behaviour(&self) -> bool {
match self {
DialOpts(Opts::WithPeerId(WithPeerId { .. })) => true,
DialOpts(Opts::WithPeerIdWithAddresses(WithPeerIdWithAddresses {
extend_addresses_through_behaviour,
..
})) => *extend_addresses_through_behaviour,
DialOpts(Opts::WithoutPeerIdWithAddress(WithoutPeerIdWithAddress { .. })) => true,
}
}
pub(crate) fn peer_condition(&self) -> PeerCondition {
match self {
DialOpts(
Opts::WithPeerId(WithPeerId { condition, .. })
| Opts::WithPeerIdWithAddresses(WithPeerIdWithAddresses { condition, .. }),
) => *condition,
DialOpts(Opts::WithoutPeerIdWithAddress(WithoutPeerIdWithAddress { .. })) => {
PeerCondition::Always
}
}
}
pub(crate) fn dial_concurrency_override(&self) -> Option<NonZeroU8> {
match self {
DialOpts(Opts::WithPeerId(WithPeerId {
dial_concurrency_factor_override,
..
})) => *dial_concurrency_factor_override,
DialOpts(Opts::WithPeerIdWithAddresses(WithPeerIdWithAddresses {
dial_concurrency_factor_override,
..
})) => *dial_concurrency_factor_override,
DialOpts(Opts::WithoutPeerIdWithAddress(WithoutPeerIdWithAddress { .. })) => None,
}
}
pub(crate) fn role_override(&self) -> Endpoint {
match self {
DialOpts(Opts::WithPeerId(WithPeerId { role_override, .. })) => *role_override,
DialOpts(Opts::WithPeerIdWithAddresses(WithPeerIdWithAddresses {
role_override,
..
})) => *role_override,
DialOpts(Opts::WithoutPeerIdWithAddress(WithoutPeerIdWithAddress {
role_override,
..
})) => *role_override,
}
}
} }
impl From<Multiaddr> for DialOpts { impl From<Multiaddr> for DialOpts {

View File

@ -122,7 +122,6 @@ pub use registry::{AddAddressResult, AddressRecord, AddressScore};
use connection::pool::{EstablishedConnection, Pool, PoolConfig, PoolEvent}; use connection::pool::{EstablishedConnection, Pool, PoolConfig, PoolEvent};
use connection::IncomingInfo; use connection::IncomingInfo;
use dial_opts::{DialOpts, PeerCondition}; use dial_opts::{DialOpts, PeerCondition};
use either::Either;
use futures::{executor::ThreadPoolBuilder, prelude::*, stream::FusedStream}; use futures::{executor::ThreadPoolBuilder, prelude::*, stream::FusedStream};
use libp2p_core::connection::ConnectionId; use libp2p_core::connection::ConnectionId;
use libp2p_core::muxing::SubstreamBox; use libp2p_core::muxing::SubstreamBox;
@ -138,7 +137,6 @@ use libp2p_core::{
use registry::{AddressIntoIter, Addresses}; use registry::{AddressIntoIter, Addresses};
use smallvec::SmallVec; use smallvec::SmallVec;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::iter;
use std::num::{NonZeroU32, NonZeroU8, NonZeroUsize}; use std::num::{NonZeroU32, NonZeroU8, NonZeroUsize};
use std::{ use std::{
convert::TryFrom, convert::TryFrom,
@ -507,42 +505,32 @@ where
fn dial_with_handler( fn dial_with_handler(
&mut self, &mut self,
swarm_dial_opts: DialOpts, dial_opts: DialOpts,
handler: <TBehaviour as NetworkBehaviour>::ConnectionHandler, handler: <TBehaviour as NetworkBehaviour>::ConnectionHandler,
) -> Result<(), DialError> { ) -> Result<(), DialError> {
let (peer_id, addresses, dial_concurrency_factor_override, role_override) = let peer_id = dial_opts
match swarm_dial_opts.0 { .get_or_parse_peer_id()
// Dial a known peer. .map_err(DialError::InvalidPeerId)?;
dial_opts::Opts::WithPeerId(dial_opts::WithPeerId { let condition = dial_opts.peer_condition();
peer_id,
condition,
role_override,
dial_concurrency_factor_override,
})
| dial_opts::Opts::WithPeerIdWithAddresses(dial_opts::WithPeerIdWithAddresses {
peer_id,
condition,
role_override,
dial_concurrency_factor_override,
..
}) => {
// Check [`PeerCondition`] if provided.
let condition_matched = match condition {
PeerCondition::Disconnected => !self.is_connected(&peer_id),
PeerCondition::NotDialing => !self.pool.is_dialing(peer_id),
PeerCondition::Always => true,
};
if !condition_matched {
#[allow(deprecated)]
self.behaviour.inject_dial_failure(
Some(peer_id),
handler,
&DialError::DialPeerConditionFalse(condition),
);
return Err(DialError::DialPeerConditionFalse(condition)); let should_dial = match (condition, peer_id) {
(PeerCondition::Always, _) => true,
(PeerCondition::Disconnected, None) => true,
(PeerCondition::NotDialing, None) => true,
(PeerCondition::Disconnected, Some(peer_id)) => !self.pool.is_connected(peer_id),
(PeerCondition::NotDialing, Some(peer_id)) => !self.pool.is_dialing(peer_id),
};
if !should_dial {
let e = DialError::DialPeerConditionFalse(condition);
#[allow(deprecated)]
self.behaviour.inject_dial_failure(peer_id, handler, &e);
return Err(e);
} }
if let Some(peer_id) = peer_id {
// Check if peer is banned. // Check if peer is banned.
if self.banned_peers.contains(&peer_id) { if self.banned_peers.contains(&peer_id) {
let error = DialError::Banned; let error = DialError::Banned;
@ -551,30 +539,16 @@ where
.inject_dial_failure(Some(peer_id), handler, &error); .inject_dial_failure(Some(peer_id), handler, &error);
return Err(error); return Err(error);
} }
}
// Retrieve the addresses to dial.
let addresses = { let addresses = {
let mut addresses = match swarm_dial_opts.0 { let mut addresses = dial_opts.get_addresses();
dial_opts::Opts::WithPeerId(dial_opts::WithPeerId { .. }) => {
self.behaviour.addresses_of_peer(&peer_id) if let Some(peer_id) = peer_id {
if dial_opts.extend_addresses_through_behaviour() {
addresses.extend(self.behaviour.addresses_of_peer(&peer_id));
} }
dial_opts::Opts::WithPeerIdWithAddresses(
dial_opts::WithPeerIdWithAddresses {
peer_id,
mut addresses,
extend_addresses_through_behaviour,
..
},
) => {
if extend_addresses_through_behaviour {
addresses.extend(self.behaviour.addresses_of_peer(&peer_id))
} }
addresses
}
dial_opts::Opts::WithoutPeerIdWithAddress { .. } => {
unreachable!("Due to outer match.")
}
};
let mut unique_addresses = HashSet::new(); let mut unique_addresses = HashSet::new();
addresses.retain(|addr| { addresses.retain(|addr| {
@ -585,61 +559,18 @@ where
if addresses.is_empty() { if addresses.is_empty() {
let error = DialError::NoAddresses; let error = DialError::NoAddresses;
#[allow(deprecated)] #[allow(deprecated)]
self.behaviour self.behaviour.inject_dial_failure(peer_id, handler, &error);
.inject_dial_failure(Some(peer_id), handler, &error);
return Err(error); return Err(error);
}; };
addresses addresses
}; };
(
Some(peer_id),
Either::Left(addresses.into_iter()),
dial_concurrency_factor_override,
role_override,
)
}
// Dial an unknown peer.
dial_opts::Opts::WithoutPeerIdWithAddress(
dial_opts::WithoutPeerIdWithAddress {
address,
role_override,
},
) => {
// If the address ultimately encapsulates an expected peer ID, dial that peer
// such that any mismatch is detected. We do not "pop off" the `P2p` protocol
// from the address, because it may be used by the `Transport`, i.e. `P2p`
// is a protocol component that can influence any transport, like `libp2p-dns`.
let peer_id = match address
.iter()
.last()
.and_then(|p| {
if let Protocol::P2p(ma) = p {
Some(PeerId::try_from(ma))
} else {
None
}
})
.transpose()
{
Ok(peer_id) => peer_id,
Err(multihash) => return Err(DialError::InvalidPeerId(multihash)),
};
(
peer_id,
Either::Right(iter::once(address)),
None,
role_override,
)
}
};
let dials = addresses let dials = addresses
.into_iter()
.map(|a| match p2p_addr(peer_id, a) { .map(|a| match p2p_addr(peer_id, a) {
Ok(address) => { Ok(address) => {
let dial = match role_override { let dial = match dial_opts.role_override() {
Endpoint::Dialer => self.transport.dial(address.clone()), Endpoint::Dialer => self.transport.dial(address.clone()),
Endpoint::Listener => self.transport.dial_as_listener(address.clone()), Endpoint::Listener => self.transport.dial_as_listener(address.clone()),
}; };
@ -662,8 +593,8 @@ where
dials, dials,
peer_id, peer_id,
handler, handler,
role_override, dial_opts.role_override(),
dial_concurrency_factor_override, dial_opts.dial_concurrency_override(),
) { ) {
Ok(_connection_id) => Ok(()), Ok(_connection_id) => Ok(()),
Err((connection_limit, handler)) => { Err((connection_limit, handler)) => {
@ -1088,9 +1019,9 @@ where
return Some(SwarmEvent::Behaviour(event)) return Some(SwarmEvent::Behaviour(event))
} }
NetworkBehaviourAction::Dial { opts, handler } => { NetworkBehaviourAction::Dial { opts, handler } => {
let peer_id = opts.get_peer_id(); let peer_id = opts.get_or_parse_peer_id();
if let Ok(()) = self.dial_with_handler(opts, handler) { if let Ok(()) = self.dial_with_handler(opts, handler) {
if let Some(peer_id) = peer_id { if let Ok(Some(peer_id)) = peer_id {
return Some(SwarmEvent::Dialing(peer_id)); return Some(SwarmEvent::Dialing(peer_id));
} }
} }
@ -2516,6 +2447,8 @@ mod tests {
_ => panic!("Was expecting the listen address to be reported"), _ => panic!("Was expecting the listen address to be reported"),
})); }));
swarm.listened_addrs.clear(); // This is a hack to actually execute the dial to ourselves which would otherwise be filtered.
swarm.dial(local_address.clone()).unwrap(); swarm.dial(local_address.clone()).unwrap();
let mut got_dial_err = false; let mut got_dial_err = false;