refactor(swarm): express dial logic linearly (#3253)

Previously, the logic within `Swarm::dial` involved fairly convoluted `match` expressions. This patch refactors this function to use new utility functions introduced on `DialOpts` to handle one concern at a time.

This has the advantage that we are covering slightly more cases now. Because we are parsing the `PeerId` only once at the top, checks like banning will now also act on dials that specify the `PeerId` as part of the `/p2p` protocol.
This commit is contained in:
Thomas Eizinger
2022-12-23 03:44:58 +11:00
committed by GitHub
parent 1765ae0395
commit d5f4acc6ed
2 changed files with 156 additions and 123 deletions

View File

@ -20,6 +20,8 @@
// DEALINGS IN THE SOFTWARE.
use libp2p_core::connection::Endpoint;
use libp2p_core::multiaddr::Protocol;
use libp2p_core::multihash::Multihash;
use libp2p_core::{Multiaddr, PeerId};
use std::num::NonZeroU8;
@ -79,6 +81,104 @@ impl DialOpts {
DialOpts(Opts::WithoutPeerIdWithAddress(_)) => None,
}
}
/// Retrieves the [`PeerId`] from the [`DialOpts`] if specified or otherwise tries to parse it
/// from the multihash in the `/p2p` part of the address, if present.
///
/// Note: A [`Multiaddr`] with something else other than a [`PeerId`] within the `/p2p` protocol is invalid as per specification.
/// Unfortunately, we are not making good use of the type system here.
/// Really, this function should be merged with [`DialOpts::get_peer_id`] above.
/// If it weren't for the parsing error, the function signatures would be the same.
///
/// See <https://github.com/multiformats/rust-multiaddr/issues/73>.
pub(crate) fn get_or_parse_peer_id(&self) -> Result<Option<PeerId>, Multihash> {
match self {
DialOpts(Opts::WithPeerId(WithPeerId { peer_id, .. })) => Ok(Some(*peer_id)),
DialOpts(Opts::WithPeerIdWithAddresses(WithPeerIdWithAddresses {
peer_id, ..
})) => Ok(Some(*peer_id)),
DialOpts(Opts::WithoutPeerIdWithAddress(WithoutPeerIdWithAddress {
address, ..
})) => {
let peer_id = address
.iter()
.last()
.and_then(|p| {
if let Protocol::P2p(ma) = p {
Some(PeerId::try_from(ma))
} else {
None
}
})
.transpose()?;
Ok(peer_id)
}
}
}
pub(crate) fn get_addresses(&self) -> Vec<Multiaddr> {
match self {
DialOpts(Opts::WithPeerId(WithPeerId { .. })) => vec![],
DialOpts(Opts::WithPeerIdWithAddresses(WithPeerIdWithAddresses {
addresses, ..
})) => addresses.clone(),
DialOpts(Opts::WithoutPeerIdWithAddress(WithoutPeerIdWithAddress {
address, ..
})) => vec![address.clone()],
}
}
pub(crate) fn extend_addresses_through_behaviour(&self) -> bool {
match self {
DialOpts(Opts::WithPeerId(WithPeerId { .. })) => true,
DialOpts(Opts::WithPeerIdWithAddresses(WithPeerIdWithAddresses {
extend_addresses_through_behaviour,
..
})) => *extend_addresses_through_behaviour,
DialOpts(Opts::WithoutPeerIdWithAddress(WithoutPeerIdWithAddress { .. })) => true,
}
}
pub(crate) fn peer_condition(&self) -> PeerCondition {
match self {
DialOpts(
Opts::WithPeerId(WithPeerId { condition, .. })
| Opts::WithPeerIdWithAddresses(WithPeerIdWithAddresses { condition, .. }),
) => *condition,
DialOpts(Opts::WithoutPeerIdWithAddress(WithoutPeerIdWithAddress { .. })) => {
PeerCondition::Always
}
}
}
pub(crate) fn dial_concurrency_override(&self) -> Option<NonZeroU8> {
match self {
DialOpts(Opts::WithPeerId(WithPeerId {
dial_concurrency_factor_override,
..
})) => *dial_concurrency_factor_override,
DialOpts(Opts::WithPeerIdWithAddresses(WithPeerIdWithAddresses {
dial_concurrency_factor_override,
..
})) => *dial_concurrency_factor_override,
DialOpts(Opts::WithoutPeerIdWithAddress(WithoutPeerIdWithAddress { .. })) => None,
}
}
pub(crate) fn role_override(&self) -> Endpoint {
match self {
DialOpts(Opts::WithPeerId(WithPeerId { role_override, .. })) => *role_override,
DialOpts(Opts::WithPeerIdWithAddresses(WithPeerIdWithAddresses {
role_override,
..
})) => *role_override,
DialOpts(Opts::WithoutPeerIdWithAddress(WithoutPeerIdWithAddress {
role_override,
..
})) => *role_override,
}
}
}
impl From<Multiaddr> for DialOpts {

View File

@ -122,7 +122,6 @@ pub use registry::{AddAddressResult, AddressRecord, AddressScore};
use connection::pool::{EstablishedConnection, Pool, PoolConfig, PoolEvent};
use connection::IncomingInfo;
use dial_opts::{DialOpts, PeerCondition};
use either::Either;
use futures::{executor::ThreadPoolBuilder, prelude::*, stream::FusedStream};
use libp2p_core::connection::ConnectionId;
use libp2p_core::muxing::SubstreamBox;
@ -138,7 +137,6 @@ use libp2p_core::{
use registry::{AddressIntoIter, Addresses};
use smallvec::SmallVec;
use std::collections::{HashMap, HashSet};
use std::iter;
use std::num::{NonZeroU32, NonZeroU8, NonZeroUsize};
use std::{
convert::TryFrom,
@ -507,42 +505,32 @@ where
fn dial_with_handler(
&mut self,
swarm_dial_opts: DialOpts,
dial_opts: DialOpts,
handler: <TBehaviour as NetworkBehaviour>::ConnectionHandler,
) -> Result<(), DialError> {
let (peer_id, addresses, dial_concurrency_factor_override, role_override) =
match swarm_dial_opts.0 {
// Dial a known peer.
dial_opts::Opts::WithPeerId(dial_opts::WithPeerId {
peer_id,
condition,
role_override,
dial_concurrency_factor_override,
})
| dial_opts::Opts::WithPeerIdWithAddresses(dial_opts::WithPeerIdWithAddresses {
peer_id,
condition,
role_override,
dial_concurrency_factor_override,
..
}) => {
// Check [`PeerCondition`] if provided.
let condition_matched = match condition {
PeerCondition::Disconnected => !self.is_connected(&peer_id),
PeerCondition::NotDialing => !self.pool.is_dialing(peer_id),
PeerCondition::Always => true,
};
if !condition_matched {
#[allow(deprecated)]
self.behaviour.inject_dial_failure(
Some(peer_id),
handler,
&DialError::DialPeerConditionFalse(condition),
);
let peer_id = dial_opts
.get_or_parse_peer_id()
.map_err(DialError::InvalidPeerId)?;
let condition = dial_opts.peer_condition();
return Err(DialError::DialPeerConditionFalse(condition));
let should_dial = match (condition, peer_id) {
(PeerCondition::Always, _) => true,
(PeerCondition::Disconnected, None) => true,
(PeerCondition::NotDialing, None) => true,
(PeerCondition::Disconnected, Some(peer_id)) => !self.pool.is_connected(peer_id),
(PeerCondition::NotDialing, Some(peer_id)) => !self.pool.is_dialing(peer_id),
};
if !should_dial {
let e = DialError::DialPeerConditionFalse(condition);
#[allow(deprecated)]
self.behaviour.inject_dial_failure(peer_id, handler, &e);
return Err(e);
}
if let Some(peer_id) = peer_id {
// Check if peer is banned.
if self.banned_peers.contains(&peer_id) {
let error = DialError::Banned;
@ -551,30 +539,16 @@ where
.inject_dial_failure(Some(peer_id), handler, &error);
return Err(error);
}
}
// Retrieve the addresses to dial.
let addresses = {
let mut addresses = match swarm_dial_opts.0 {
dial_opts::Opts::WithPeerId(dial_opts::WithPeerId { .. }) => {
self.behaviour.addresses_of_peer(&peer_id)
let mut addresses = dial_opts.get_addresses();
if let Some(peer_id) = peer_id {
if dial_opts.extend_addresses_through_behaviour() {
addresses.extend(self.behaviour.addresses_of_peer(&peer_id));
}
dial_opts::Opts::WithPeerIdWithAddresses(
dial_opts::WithPeerIdWithAddresses {
peer_id,
mut addresses,
extend_addresses_through_behaviour,
..
},
) => {
if extend_addresses_through_behaviour {
addresses.extend(self.behaviour.addresses_of_peer(&peer_id))
}
addresses
}
dial_opts::Opts::WithoutPeerIdWithAddress { .. } => {
unreachable!("Due to outer match.")
}
};
let mut unique_addresses = HashSet::new();
addresses.retain(|addr| {
@ -585,61 +559,18 @@ where
if addresses.is_empty() {
let error = DialError::NoAddresses;
#[allow(deprecated)]
self.behaviour
.inject_dial_failure(Some(peer_id), handler, &error);
self.behaviour.inject_dial_failure(peer_id, handler, &error);
return Err(error);
};
addresses
};
(
Some(peer_id),
Either::Left(addresses.into_iter()),
dial_concurrency_factor_override,
role_override,
)
}
// Dial an unknown peer.
dial_opts::Opts::WithoutPeerIdWithAddress(
dial_opts::WithoutPeerIdWithAddress {
address,
role_override,
},
) => {
// If the address ultimately encapsulates an expected peer ID, dial that peer
// such that any mismatch is detected. We do not "pop off" the `P2p` protocol
// from the address, because it may be used by the `Transport`, i.e. `P2p`
// is a protocol component that can influence any transport, like `libp2p-dns`.
let peer_id = match address
.iter()
.last()
.and_then(|p| {
if let Protocol::P2p(ma) = p {
Some(PeerId::try_from(ma))
} else {
None
}
})
.transpose()
{
Ok(peer_id) => peer_id,
Err(multihash) => return Err(DialError::InvalidPeerId(multihash)),
};
(
peer_id,
Either::Right(iter::once(address)),
None,
role_override,
)
}
};
let dials = addresses
.into_iter()
.map(|a| match p2p_addr(peer_id, a) {
Ok(address) => {
let dial = match role_override {
let dial = match dial_opts.role_override() {
Endpoint::Dialer => self.transport.dial(address.clone()),
Endpoint::Listener => self.transport.dial_as_listener(address.clone()),
};
@ -662,8 +593,8 @@ where
dials,
peer_id,
handler,
role_override,
dial_concurrency_factor_override,
dial_opts.role_override(),
dial_opts.dial_concurrency_override(),
) {
Ok(_connection_id) => Ok(()),
Err((connection_limit, handler)) => {
@ -1088,9 +1019,9 @@ where
return Some(SwarmEvent::Behaviour(event))
}
NetworkBehaviourAction::Dial { opts, handler } => {
let peer_id = opts.get_peer_id();
let peer_id = opts.get_or_parse_peer_id();
if let Ok(()) = self.dial_with_handler(opts, handler) {
if let Some(peer_id) = peer_id {
if let Ok(Some(peer_id)) = peer_id {
return Some(SwarmEvent::Dialing(peer_id));
}
}
@ -2516,6 +2447,8 @@ mod tests {
_ => panic!("Was expecting the listen address to be reported"),
}));
swarm.listened_addrs.clear(); // This is a hack to actually execute the dial to ourselves which would otherwise be filtered.
swarm.dial(local_address.clone()).unwrap();
let mut got_dial_err = false;