From cbdbf656c0bf1930cb2c484ab8a49ac91c3daffc Mon Sep 17 00:00:00 2001 From: Max Inden Date: Tue, 18 Aug 2020 14:51:03 +0200 Subject: [PATCH] protocols/mdns: Make libp2p-mdns socket agnostic (#1699) Allow libp2p-mdns to use either async-std or tokio to drive required UDP socket. Co-authored-by: Roman Borschel --- Cargo.toml | 6 +- protocols/mdns/CHANGELOG.md | 3 + protocols/mdns/Cargo.toml | 3 +- protocols/mdns/src/behaviour.rs | 398 ++++++++++++++++---------------- protocols/mdns/src/lib.rs | 8 +- protocols/mdns/src/service.rs | 250 ++++++++++++-------- src/lib.rs | 4 +- 7 files changed, 370 insertions(+), 302 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 09af466d..c217d138 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ default = [ "identify", "kad", "gossipsub", - "mdns", + "mdns-async-std", "mplex", "noise", "ping", @@ -38,7 +38,8 @@ floodsub = ["libp2p-floodsub"] identify = ["libp2p-identify"] kad = ["libp2p-kad"] gossipsub = ["libp2p-gossipsub"] -mdns = ["libp2p-mdns"] +mdns-async-std = ["libp2p-mdns", "libp2p-mdns/async-std"] +mdns-tokio = ["libp2p-mdns", "libp2p-mdns/tokio"] mplex = ["libp2p-mplex"] noise = ["libp2p-noise"] ping = ["libp2p-ping"] @@ -96,6 +97,7 @@ libp2p-websocket = { version = "0.22.0", path = "transports/websocket", optional [dev-dependencies] async-std = "1.6.2" env_logger = "0.7.1" +tokio = { version = "0.2", features = ["io-util", "io-std", "stream"] } [workspace] members = [ diff --git a/protocols/mdns/CHANGELOG.md b/protocols/mdns/CHANGELOG.md index 1d8e9ad6..1898bbe6 100644 --- a/protocols/mdns/CHANGELOG.md +++ b/protocols/mdns/CHANGELOG.md @@ -2,6 +2,9 @@ - Bump `libp2p-core` and `libp2p-swarm` dependencies. +- Allow libp2p-mdns to use either async-std or tokio to drive required UDP + socket ([PR 1699](https://github.com/libp2p/rust-libp2p/pull/1699)). + # 0.20.0 [2020-07-01] - Updated dependencies. diff --git a/protocols/mdns/Cargo.toml b/protocols/mdns/Cargo.toml index 9d86bd35..f8885673 100644 --- a/protocols/mdns/Cargo.toml +++ b/protocols/mdns/Cargo.toml @@ -10,7 +10,7 @@ keywords = ["peer-to-peer", "libp2p", "networking"] categories = ["network-programming", "asynchronous"] [dependencies] -async-std = "1.6.2" +async-std = { version = "1.6.2", optional = true } data-encoding = "2.0" dns-parser = "0.8" either = "1.5.3" @@ -22,6 +22,7 @@ log = "0.4" net2 = "0.2" rand = "0.7" smallvec = "1.0" +tokio = { version = "0.2", default-features = false, features = ["udp"], optional = true } void = "1.0" wasm-timer = "0.2.4" diff --git a/protocols/mdns/src/behaviour.rs b/protocols/mdns/src/behaviour.rs index abdd4d44..b9b5649d 100644 --- a/protocols/mdns/src/behaviour.rs +++ b/protocols/mdns/src/behaviour.rs @@ -18,7 +18,7 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::service::{MdnsService, MdnsPacket, build_query_response, build_service_discovery_response}; +use crate::service::{MdnsPacket, build_query_response, build_service_discovery_response}; use futures::prelude::*; use libp2p_core::{ Multiaddr, @@ -41,11 +41,14 @@ use wasm_timer::{Delay, Instant}; const MDNS_RESPONSE_TTL: std::time::Duration = Duration::from_secs(5 * 60); +macro_rules! codegen { + ($feature_name:expr, $behaviour_name:ident, $maybe_busy_wrapper:ident, $service_name:ty) => { + /// A `NetworkBehaviour` for mDNS. Automatically discovers peers on the local network and adds /// them to the topology. -pub struct Mdns { +pub struct $behaviour_name { /// The inner service. - service: MaybeBusyMdnsService, + service: $maybe_busy_wrapper, /// List of nodes that we have discovered, the address, and when their TTL expires. /// @@ -63,37 +66,37 @@ pub struct Mdns { /// and a `MdnsPacket` (similar to the old Tokio socket send style). The two states are thus `Free` /// with an `MdnsService` or `Busy` with a future returning the original `MdnsService` and an /// `MdnsPacket`. -enum MaybeBusyMdnsService { - Free(MdnsService), - Busy(Pin + Send>>), +enum $maybe_busy_wrapper { + Free($service_name), + Busy(Pin + Send>>), Poisoned, } -impl fmt::Debug for MaybeBusyMdnsService { +impl fmt::Debug for $maybe_busy_wrapper { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - MaybeBusyMdnsService::Free(service) => { - fmt.debug_struct("MaybeBusyMdnsService::Free") + $maybe_busy_wrapper::Free(service) => { + fmt.debug_struct("$maybe_busy_wrapper::Free") .field("service", service) .finish() }, - MaybeBusyMdnsService::Busy(_) => { - fmt.debug_struct("MaybeBusyMdnsService::Busy") + $maybe_busy_wrapper::Busy(_) => { + fmt.debug_struct("$maybe_busy_wrapper::Busy") .finish() } - MaybeBusyMdnsService::Poisoned => { - fmt.debug_struct("MaybeBusyMdnsService::Poisoned") + $maybe_busy_wrapper::Poisoned => { + fmt.debug_struct("$maybe_busy_wrapper::Poisoned") .finish() } } } } -impl Mdns { +impl $behaviour_name { /// Builds a new `Mdns` behaviour. - pub fn new() -> io::Result { - Ok(Mdns { - service: MaybeBusyMdnsService::Free(MdnsService::new()?), + pub fn new() -> io::Result<$behaviour_name> { + Ok($behaviour_name { + service: $maybe_busy_wrapper::Free(<$service_name>::new()?), discovered_nodes: SmallVec::new(), closest_expiration: None, }) @@ -110,6 +113,191 @@ impl Mdns { } } +impl NetworkBehaviour for $behaviour_name { + type ProtocolsHandler = DummyProtocolsHandler; + type OutEvent = MdnsEvent; + + fn new_handler(&mut self) -> Self::ProtocolsHandler { + DummyProtocolsHandler::default() + } + + fn addresses_of_peer(&mut self, peer_id: &PeerId) -> Vec { + let now = Instant::now(); + self.discovered_nodes + .iter() + .filter(move |(p, _, expires)| p == peer_id && *expires > now) + .map(|(_, addr, _)| addr.clone()) + .collect() + } + + fn inject_connected(&mut self, _: &PeerId) {} + + fn inject_disconnected(&mut self, _: &PeerId) {} + + fn inject_event( + &mut self, + _: PeerId, + _: ConnectionId, + _ev: ::OutEvent, + ) { + void::unreachable(_ev) + } + + fn poll( + &mut self, + cx: &mut Context<'_>, + params: &mut impl PollParameters, + ) -> Poll< + NetworkBehaviourAction< + ::InEvent, + Self::OutEvent, + >, + > { + // Remove expired peers. + if let Some(ref mut closest_expiration) = self.closest_expiration { + match Future::poll(Pin::new(closest_expiration), cx) { + Poll::Ready(Ok(())) => { + let now = Instant::now(); + let mut expired = SmallVec::<[(PeerId, Multiaddr); 4]>::new(); + while let Some(pos) = self.discovered_nodes.iter().position(|(_, _, exp)| *exp < now) { + let (peer_id, addr, _) = self.discovered_nodes.remove(pos); + expired.push((peer_id, addr)); + } + + if !expired.is_empty() { + let event = MdnsEvent::Expired(ExpiredAddrsIter { + inner: expired.into_iter(), + }); + + return Poll::Ready(NetworkBehaviourAction::GenerateEvent(event)); + } + }, + Poll::Pending => (), + Poll::Ready(Err(err)) => warn!("timer has errored: {:?}", err), + } + } + + // Polling the mDNS service, and obtain the list of nodes discovered this round. + let discovered = loop { + let service = mem::replace(&mut self.service, $maybe_busy_wrapper::Poisoned); + + let packet = match service { + $maybe_busy_wrapper::Free(service) => { + self.service = $maybe_busy_wrapper::Busy(Box::pin(service.next())); + continue; + }, + $maybe_busy_wrapper::Busy(mut fut) => { + match fut.as_mut().poll(cx) { + Poll::Ready((service, packet)) => { + self.service = $maybe_busy_wrapper::Free(service); + packet + }, + Poll::Pending => { + self.service = $maybe_busy_wrapper::Busy(fut); + return Poll::Pending; + } + } + }, + $maybe_busy_wrapper::Poisoned => panic!("Mdns poisoned"), + }; + + match packet { + MdnsPacket::Query(query) => { + // MaybeBusyMdnsService should always be Free. + if let $maybe_busy_wrapper::Free(ref mut service) = self.service { + let resp = build_query_response( + query.query_id(), + params.local_peer_id().clone(), + params.listened_addresses().into_iter(), + MDNS_RESPONSE_TTL, + ); + service.enqueue_response(resp.unwrap()); + } else { debug_assert!(false); } + }, + MdnsPacket::Response(response) => { + // We replace the IP address with the address we observe the + // remote as and the address they listen on. + let obs_ip = Protocol::from(response.remote_addr().ip()); + let obs_port = Protocol::Udp(response.remote_addr().port()); + let observed: Multiaddr = iter::once(obs_ip) + .chain(iter::once(obs_port)) + .collect(); + + let mut discovered: SmallVec<[_; 4]> = SmallVec::new(); + for peer in response.discovered_peers() { + if peer.id() == params.local_peer_id() { + continue; + } + + let new_expiration = Instant::now() + peer.ttl(); + + let mut addrs: Vec = Vec::new(); + for addr in peer.addresses() { + if let Some(new_addr) = address_translation(&addr, &observed) { + addrs.push(new_addr.clone()) + } + addrs.push(addr.clone()) + } + + for addr in addrs { + if let Some((_, _, cur_expires)) = self.discovered_nodes.iter_mut() + .find(|(p, a, _)| p == peer.id() && *a == addr) + { + *cur_expires = cmp::max(*cur_expires, new_expiration); + } else { + self.discovered_nodes.push((peer.id().clone(), addr.clone(), new_expiration)); + } + + discovered.push((peer.id().clone(), addr)); + } + } + + break discovered; + }, + MdnsPacket::ServiceDiscovery(disc) => { + // MaybeBusyMdnsService should always be Free. + if let $maybe_busy_wrapper::Free(ref mut service) = self.service { + let resp = build_service_discovery_response( + disc.query_id(), + MDNS_RESPONSE_TTL, + ); + service.enqueue_response(resp); + } else { debug_assert!(false); } + }, + } + }; + + // Getting this far implies that we discovered new nodes. As the final step, we need to + // refresh `closest_expiration`. + self.closest_expiration = self.discovered_nodes.iter() + .fold(None, |exp, &(_, _, elem_exp)| { + Some(exp.map(|exp| cmp::min(exp, elem_exp)).unwrap_or(elem_exp)) + }) + .map(Delay::new_at); + + Poll::Ready(NetworkBehaviourAction::GenerateEvent(MdnsEvent::Discovered(DiscoveredAddrsIter { + inner: discovered.into_iter(), + }))) + } +} + +impl fmt::Debug for $behaviour_name { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Mdns") + .field("service", &self.service) + .finish() + } +} + +}; +} + +#[cfg(feature = "async-std")] +codegen!("async-std", Mdns, MaybeBusyMdnsService, crate::service::MdnsService); + +#[cfg(feature = "tokio")] +codegen!("tokio", TokioMdns, MaybeBusyTokioMdnsService, crate::service::TokioMdnsService); + /// Event that can be produced by the `Mdns` behaviour. #[derive(Debug)] pub enum MdnsEvent { @@ -180,179 +368,3 @@ impl fmt::Debug for ExpiredAddrsIter { .finish() } } - -impl NetworkBehaviour for Mdns { - type ProtocolsHandler = DummyProtocolsHandler; - type OutEvent = MdnsEvent; - - fn new_handler(&mut self) -> Self::ProtocolsHandler { - DummyProtocolsHandler::default() - } - - fn addresses_of_peer(&mut self, peer_id: &PeerId) -> Vec { - let now = Instant::now(); - self.discovered_nodes - .iter() - .filter(move |(p, _, expires)| p == peer_id && *expires > now) - .map(|(_, addr, _)| addr.clone()) - .collect() - } - - fn inject_connected(&mut self, _: &PeerId) {} - - fn inject_disconnected(&mut self, _: &PeerId) {} - - fn inject_event( - &mut self, - _: PeerId, - _: ConnectionId, - _ev: ::OutEvent, - ) { - void::unreachable(_ev) - } - - fn poll( - &mut self, - cx: &mut Context<'_>, - params: &mut impl PollParameters, - ) -> Poll< - NetworkBehaviourAction< - ::InEvent, - Self::OutEvent, - >, - > { - // Remove expired peers. - if let Some(ref mut closest_expiration) = self.closest_expiration { - match Future::poll(Pin::new(closest_expiration), cx) { - Poll::Ready(Ok(())) => { - let now = Instant::now(); - let mut expired = SmallVec::<[(PeerId, Multiaddr); 4]>::new(); - while let Some(pos) = self.discovered_nodes.iter().position(|(_, _, exp)| *exp < now) { - let (peer_id, addr, _) = self.discovered_nodes.remove(pos); - expired.push((peer_id, addr)); - } - - if !expired.is_empty() { - let event = MdnsEvent::Expired(ExpiredAddrsIter { - inner: expired.into_iter(), - }); - - return Poll::Ready(NetworkBehaviourAction::GenerateEvent(event)); - } - }, - Poll::Pending => (), - Poll::Ready(Err(err)) => warn!("timer has errored: {:?}", err), - } - } - - // Polling the mDNS service, and obtain the list of nodes discovered this round. - let discovered = loop { - let service = mem::replace(&mut self.service, MaybeBusyMdnsService::Poisoned); - - let packet = match service { - MaybeBusyMdnsService::Free(service) => { - self.service = MaybeBusyMdnsService::Busy(Box::pin(service.next())); - continue; - }, - MaybeBusyMdnsService::Busy(mut fut) => { - match fut.as_mut().poll(cx) { - Poll::Ready((service, packet)) => { - self.service = MaybeBusyMdnsService::Free(service); - packet - }, - Poll::Pending => { - self.service = MaybeBusyMdnsService::Busy(fut); - return Poll::Pending; - } - } - }, - MaybeBusyMdnsService::Poisoned => panic!("Mdns poisoned"), - }; - - match packet { - MdnsPacket::Query(query) => { - // MaybeBusyMdnsService should always be Free. - if let MaybeBusyMdnsService::Free(ref mut service) = self.service { - let resp = build_query_response( - query.query_id(), - params.local_peer_id().clone(), - params.listened_addresses().into_iter(), - MDNS_RESPONSE_TTL, - ); - service.enqueue_response(resp.unwrap()); - } else { debug_assert!(false); } - }, - MdnsPacket::Response(response) => { - // We replace the IP address with the address we observe the - // remote as and the address they listen on. - let obs_ip = Protocol::from(response.remote_addr().ip()); - let obs_port = Protocol::Udp(response.remote_addr().port()); - let observed: Multiaddr = iter::once(obs_ip) - .chain(iter::once(obs_port)) - .collect(); - - let mut discovered: SmallVec<[_; 4]> = SmallVec::new(); - for peer in response.discovered_peers() { - if peer.id() == params.local_peer_id() { - continue; - } - - let new_expiration = Instant::now() + peer.ttl(); - - let mut addrs: Vec = Vec::new(); - for addr in peer.addresses() { - if let Some(new_addr) = address_translation(&addr, &observed) { - addrs.push(new_addr.clone()) - } - addrs.push(addr.clone()) - } - - for addr in addrs { - if let Some((_, _, cur_expires)) = self.discovered_nodes.iter_mut() - .find(|(p, a, _)| p == peer.id() && *a == addr) - { - *cur_expires = cmp::max(*cur_expires, new_expiration); - } else { - self.discovered_nodes.push((peer.id().clone(), addr.clone(), new_expiration)); - } - - discovered.push((peer.id().clone(), addr)); - } - } - - break discovered; - }, - MdnsPacket::ServiceDiscovery(disc) => { - // MaybeBusyMdnsService should always be Free. - if let MaybeBusyMdnsService::Free(ref mut service) = self.service { - let resp = build_service_discovery_response( - disc.query_id(), - MDNS_RESPONSE_TTL, - ); - service.enqueue_response(resp); - } else { debug_assert!(false); } - }, - } - }; - - // Getting this far implies that we discovered new nodes. As the final step, we need to - // refresh `closest_expiration`. - self.closest_expiration = self.discovered_nodes.iter() - .fold(None, |exp, &(_, _, elem_exp)| { - Some(exp.map(|exp| cmp::min(exp, elem_exp)).unwrap_or(elem_exp)) - }) - .map(Delay::new_at); - - Poll::Ready(NetworkBehaviourAction::GenerateEvent(MdnsEvent::Discovered(DiscoveredAddrsIter { - inner: discovered.into_iter(), - }))) - } -} - -impl fmt::Debug for Mdns { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("Mdns") - .field("service", &self.service) - .finish() - } -} diff --git a/protocols/mdns/src/lib.rs b/protocols/mdns/src/lib.rs index 767ef9d0..292bd01b 100644 --- a/protocols/mdns/src/lib.rs +++ b/protocols/mdns/src/lib.rs @@ -35,8 +35,12 @@ const SERVICE_NAME: &[u8] = b"_p2p._udp.local"; /// Hardcoded name of the service used for DNS-SD. const META_QUERY_SERVICE: &[u8] = b"_services._dns-sd._udp.local"; -pub use self::behaviour::{Mdns, MdnsEvent}; -pub use self::service::MdnsService; +#[cfg(feature = "async-std")] +pub use self::{behaviour::Mdns, service::MdnsService}; +#[cfg(feature = "tokio")] +pub use self::{behaviour::TokioMdns, service::TokioMdnsService}; + +pub use self::behaviour::MdnsEvent; mod behaviour; mod dns; diff --git a/protocols/mdns/src/service.rs b/protocols/mdns/src/service.rs index 45b1132f..b3913812 100644 --- a/protocols/mdns/src/service.rs +++ b/protocols/mdns/src/service.rs @@ -19,7 +19,6 @@ // DEALINGS IN THE SOFTWARE. use crate::{SERVICE_NAME, META_QUERY_SERVICE, dns}; -use async_std::net::UdpSocket; use dns_parser::{Packet, RData}; use either::Either::{Left, Right}; use futures::{future, prelude::*}; @@ -37,6 +36,9 @@ lazy_static! { )); } +macro_rules! codegen { + ($feature_name:expr, $service_name:ident, $udp_socket:ty, $udp_socket_from_std:tt) => { + /// A running service that discovers libp2p peers and responds to other libp2p peers' queries on /// the local network. /// @@ -62,13 +64,16 @@ lazy_static! { /// # use futures::prelude::*; /// # use futures::executor::block_on; /// # use libp2p_core::{identity, Multiaddr, PeerId}; -/// # use libp2p_mdns::service::{MdnsService, MdnsPacket, build_query_response, build_service_discovery_response}; +/// # use libp2p_mdns::service::{MdnsPacket, build_query_response, build_service_discovery_response}; /// # use std::{io, time::Duration, task::Poll}; /// # fn main() { /// # let my_peer_id = PeerId::from(identity::Keypair::generate_ed25519().public()); /// # let my_listened_addrs: Vec = vec![]; -/// # block_on(async { -/// let mut service = MdnsService::new().expect("Error while creating mDNS service"); +/// # async { +/// # #[cfg(feature = "async-std")] +/// # let mut service = libp2p_mdns::service::MdnsService::new().unwrap(); +/// # #[cfg(feature = "tokio")] +/// # let mut service = libp2p_mdns::service::TokioMdnsService::new().unwrap(); /// let _future_to_poll = async { /// let (mut service, packet) = service.next().await; /// @@ -100,13 +105,16 @@ lazy_static! { /// } /// } /// }; -/// # }) +/// # }; /// # } -pub struct MdnsService { +#[cfg_attr(docsrs, doc(cfg(feature = $feature_name)))] +pub struct $service_name { /// Main socket for listening. - socket: UdpSocket, + socket: $udp_socket, + /// Socket for sending queries on the network. - query_socket: UdpSocket, + query_socket: $udp_socket, + /// Interval for sending queries. query_interval: Interval, /// Whether we send queries on the network at all. @@ -121,20 +129,20 @@ pub struct MdnsService { query_send_buffers: Vec>, } -impl MdnsService { +impl $service_name { /// Starts a new mDNS service. - pub fn new() -> io::Result { + pub fn new() -> io::Result<$service_name> { Self::new_inner(false) } /// Same as `new`, but we don't automatically send queries on the network. - pub fn silent() -> io::Result { + pub fn silent() -> io::Result<$service_name> { Self::new_inner(true) } /// Starts a new mDNS service. - fn new_inner(silent: bool) -> io::Result { - let socket = { + fn new_inner(silent: bool) -> io::Result<$service_name> { + let std_socket = { #[cfg(unix)] fn platform_specific(s: &net2::UdpBuilder) -> io::Result<()> { net2::unix::UnixUdpBuilderExt::reuse_port(s, true)?; @@ -148,17 +156,21 @@ impl MdnsService { builder.bind(("0.0.0.0", 5353))? }; - let socket = UdpSocket::from(socket); + let socket = $udp_socket_from_std(std_socket)?; + // Given that we pass an IP address to bind, which does not need to be resolved, we can + // use std::net::UdpSocket::bind, instead of its async counterpart from async-std. + let query_socket = $udp_socket_from_std( + std::net::UdpSocket::bind((Ipv4Addr::from([0u8, 0, 0, 0]), 0u16))?, + )?; + socket.set_multicast_loop_v4(true)?; socket.set_multicast_ttl_v4(255)?; // TODO: correct interfaces? socket.join_multicast_v4(From::from([224, 0, 0, 251]), Ipv4Addr::UNSPECIFIED)?; - Ok(MdnsService { + Ok($service_name { socket, - // Given that we pass an IP address to bind, which does not need to be resolved, we can - // use std::net::UdpSocket::bind, instead of its async counterpart from async-std. - query_socket: std::net::UdpSocket::bind((Ipv4Addr::from([0u8, 0, 0, 0]), 0u16))?.into(), + query_socket, query_interval: Interval::new_at(Instant::now(), Duration::from_secs(20)), silent, recv_buffer: [0; 2048], @@ -266,14 +278,24 @@ impl MdnsService { } } -impl fmt::Debug for MdnsService { +impl fmt::Debug for $service_name { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("MdnsService") + fmt.debug_struct("$service_name") .field("silent", &self.silent) .finish() } } +}; +} + +#[cfg(feature = "async-std")] +codegen!("async-std", MdnsService, async_std::net::UdpSocket, (|socket| Ok::<_, std::io::Error>(async_std::net::UdpSocket::from(socket)))); + +#[cfg(feature = "tokio")] +codegen!("tokio", TokioMdnsService, tokio::net::UdpSocket, (|socket| tokio::net::UdpSocket::from_std(socket))); + + /// A valid mDNS packet received by the service. #[derive(Debug)] pub enum MdnsPacket { @@ -556,97 +578,121 @@ impl fmt::Debug for MdnsPeer { #[cfg(test)] mod tests { - use futures::executor::block_on; - use libp2p_core::{PeerId, multiaddr::multihash}; - use std::{io::{Error, ErrorKind}, time::Duration}; - use wasm_timer::ext::TryFutureExt; - use crate::service::{MdnsPacket, MdnsService}; + macro_rules! testgen { + ($runtime_name:ident, $service_name:ty, $block_on_fn:tt) => { + mod $runtime_name { + use libp2p_core::{PeerId, multiaddr::multihash}; + use std::time::Duration; + use crate::service::MdnsPacket; - fn discover(peer_id: PeerId) { - block_on(async { - let mut service = MdnsService::new().unwrap(); - loop { - let next = service.next().await; - service = next.0; + fn discover(peer_id: PeerId) { + let fut = async { + let mut service = <$service_name>::new().unwrap(); - match next.1 { - MdnsPacket::Query(query) => { - let resp = crate::dns::build_query_response( - query.query_id(), - peer_id.clone(), - vec![].into_iter(), - Duration::from_secs(120), - ).unwrap(); - service.enqueue_response(resp); + loop { + let next = service.next().await; + service = next.0; + + match next.1 { + MdnsPacket::Query(query) => { + let resp = crate::dns::build_query_response( + query.query_id(), + peer_id.clone(), + vec![].into_iter(), + Duration::from_secs(120), + ).unwrap(); + service.enqueue_response(resp); + } + MdnsPacket::Response(response) => { + for peer in response.discovered_peers() { + if peer.id() == &peer_id { + return; + } + } + } + MdnsPacket::ServiceDiscovery(_) => panic!( + "did not expect a service discovery packet", + ) } - MdnsPacket::Response(response) => { - for peer in response.discovered_peers() { - if peer.id() == &peer_id { + } + }; + + $block_on_fn(Box::pin(fut)); + } + + // As of today the underlying UDP socket is not stubbed out. Thus tests run in parallel to + // this unit tests inter fear with it. Test needs to be run in sequence to ensure test + // properties. + #[test] + fn respect_query_interval() { + let own_ips: Vec = get_if_addrs::get_if_addrs().unwrap() + .into_iter() + .map(|i| i.addr.ip()) + .collect(); + + let fut = async { + let mut service = <$service_name>::new().unwrap(); + + let mut sent_queries = vec![]; + + loop { + let next = service.next().await; + service = next.0; + + match next.1 { + MdnsPacket::Query(query) => { + // Ignore queries from other nodes. + let source_ip = query.remote_addr().ip(); + if !own_ips.contains(&source_ip) { + continue; + } + + sent_queries.push(query); + + if sent_queries.len() > 1 { return; } } + // Ignore response packets. We don't stub out the UDP socket, thus this is + // either random noise from the network, or noise from other unit tests + // running in parallel. + MdnsPacket::Response(_) => {}, + MdnsPacket::ServiceDiscovery(_) => { + panic!("Did not expect a service discovery packet."); + }, } - MdnsPacket::ServiceDiscovery(_) => panic!("did not expect a service discovery packet") } - } - }) + }; + + $block_on_fn(Box::pin(fut)); + } + + #[test] + fn discover_normal_peer_id() { + discover(PeerId::random()) + } + + #[test] + fn discover_long_peer_id() { + let max_value = String::from_utf8(vec![b'f'; 42]).unwrap(); + let hash = multihash::Identity::digest(max_value.as_ref()); + discover(PeerId::from_multihash(hash).unwrap()) + } + } + } } - // As of today the underlying UDP socket is not stubbed out. Thus tests run in parallel to this - // unit tests inter fear with it. Test needs to be run in sequence to ensure test properties. - #[test] - fn respect_query_interval() { - let own_ips: Vec = get_if_addrs::get_if_addrs().unwrap() - .into_iter() - .map(|i| i.addr.ip()) - .collect(); + #[cfg(feature = "async-std")] + testgen!( + async_std, + crate::service::MdnsService, + (|fut| async_std::task::block_on::<_, ()>(fut)) + ); - let fut = async { - let mut service = MdnsService::new().unwrap(); - let mut sent_queries = vec![]; - - loop { - let next = service.next().await; - service = next.0; - - match next.1 { - MdnsPacket::Query(query) => { - // Ignore queries from other nodes. - let source_ip = query.remote_addr().ip(); - if !own_ips.contains(&source_ip) { - continue; - } - - sent_queries.push(query); - - if sent_queries.len() > 1 { - return Ok(()) - } - } - // Ignore response packets. We don't stub out the UDP socket, thus this is - // either random noise from the network, or noise from other unit tests running - // in parallel. - MdnsPacket::Response(_) => {}, - MdnsPacket::ServiceDiscovery(_) => { - return Err(Error::new(ErrorKind::Other, "did not expect a service discovery packet")); - }, - } - } - }; - - // TODO: This might be too long for a unit test. - block_on(fut.timeout(Duration::from_secs(41))).unwrap(); - } - - #[test] - fn discover_normal_peer_id() { - discover(PeerId::random()) - } - - #[test] - fn discover_long_peer_id() { - let max_value = String::from_utf8(vec![b'f'; 42]).unwrap(); - let hash = multihash::Identity::digest(max_value.as_ref()); - discover(PeerId::from_multihash(hash).unwrap()) - } + #[cfg(feature = "tokio")] + testgen!( + tokio, + crate::service::TokioMdnsService, + (|fut| tokio::runtime::Runtime::new().unwrap().block_on::>(fut)) + ); } diff --git a/src/lib.rs b/src/lib.rs index b5e46bab..5809386e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -195,8 +195,8 @@ pub use libp2p_gossipsub as gossipsub; #[cfg_attr(docsrs, doc(cfg(feature = "mplex")))] #[doc(inline)] pub use libp2p_mplex as mplex; -#[cfg(feature = "mdns")] -#[cfg_attr(docsrs, doc(cfg(feature = "mdns")))] +#[cfg(any(feature = "mdns-async-std", feature = "mdns-tokio"))] +#[cfg_attr(docsrs, doc(cfg(any(feature = "mdns-async-std", feature = "mdns-tokio"))))] #[cfg(not(any(target_os = "emscripten", target_os = "wasi", target_os = "unknown")))] #[doc(inline)] pub use libp2p_mdns as mdns;