protocols/mdns: Make libp2p-mdns socket agnostic (#1699)

Allow libp2p-mdns to use either async-std or tokio to drive required UDP
socket.

Co-authored-by: Roman Borschel <romanb@users.noreply.github.com>
This commit is contained in:
Max Inden
2020-08-18 14:51:03 +02:00
committed by GitHub
parent b4ad2d6297
commit cbdbf656c0
7 changed files with 370 additions and 302 deletions

View File

@ -19,7 +19,6 @@
// DEALINGS IN THE SOFTWARE.
use crate::{SERVICE_NAME, META_QUERY_SERVICE, dns};
use async_std::net::UdpSocket;
use dns_parser::{Packet, RData};
use either::Either::{Left, Right};
use futures::{future, prelude::*};
@ -37,6 +36,9 @@ lazy_static! {
));
}
macro_rules! codegen {
($feature_name:expr, $service_name:ident, $udp_socket:ty, $udp_socket_from_std:tt) => {
/// A running service that discovers libp2p peers and responds to other libp2p peers' queries on
/// the local network.
///
@ -62,13 +64,16 @@ lazy_static! {
/// # use futures::prelude::*;
/// # use futures::executor::block_on;
/// # use libp2p_core::{identity, Multiaddr, PeerId};
/// # use libp2p_mdns::service::{MdnsService, MdnsPacket, build_query_response, build_service_discovery_response};
/// # use libp2p_mdns::service::{MdnsPacket, build_query_response, build_service_discovery_response};
/// # use std::{io, time::Duration, task::Poll};
/// # fn main() {
/// # let my_peer_id = PeerId::from(identity::Keypair::generate_ed25519().public());
/// # let my_listened_addrs: Vec<Multiaddr> = vec![];
/// # block_on(async {
/// let mut service = MdnsService::new().expect("Error while creating mDNS service");
/// # async {
/// # #[cfg(feature = "async-std")]
/// # let mut service = libp2p_mdns::service::MdnsService::new().unwrap();
/// # #[cfg(feature = "tokio")]
/// # let mut service = libp2p_mdns::service::TokioMdnsService::new().unwrap();
/// let _future_to_poll = async {
/// let (mut service, packet) = service.next().await;
///
@ -100,13 +105,16 @@ lazy_static! {
/// }
/// }
/// };
/// # })
/// # };
/// # }
pub struct MdnsService {
#[cfg_attr(docsrs, doc(cfg(feature = $feature_name)))]
pub struct $service_name {
/// Main socket for listening.
socket: UdpSocket,
socket: $udp_socket,
/// Socket for sending queries on the network.
query_socket: UdpSocket,
query_socket: $udp_socket,
/// Interval for sending queries.
query_interval: Interval,
/// Whether we send queries on the network at all.
@ -121,20 +129,20 @@ pub struct MdnsService {
query_send_buffers: Vec<Vec<u8>>,
}
impl MdnsService {
impl $service_name {
/// Starts a new mDNS service.
pub fn new() -> io::Result<MdnsService> {
pub fn new() -> io::Result<$service_name> {
Self::new_inner(false)
}
/// Same as `new`, but we don't automatically send queries on the network.
pub fn silent() -> io::Result<MdnsService> {
pub fn silent() -> io::Result<$service_name> {
Self::new_inner(true)
}
/// Starts a new mDNS service.
fn new_inner(silent: bool) -> io::Result<MdnsService> {
let socket = {
fn new_inner(silent: bool) -> io::Result<$service_name> {
let std_socket = {
#[cfg(unix)]
fn platform_specific(s: &net2::UdpBuilder) -> io::Result<()> {
net2::unix::UnixUdpBuilderExt::reuse_port(s, true)?;
@ -148,17 +156,21 @@ impl MdnsService {
builder.bind(("0.0.0.0", 5353))?
};
let socket = UdpSocket::from(socket);
let socket = $udp_socket_from_std(std_socket)?;
// Given that we pass an IP address to bind, which does not need to be resolved, we can
// use std::net::UdpSocket::bind, instead of its async counterpart from async-std.
let query_socket = $udp_socket_from_std(
std::net::UdpSocket::bind((Ipv4Addr::from([0u8, 0, 0, 0]), 0u16))?,
)?;
socket.set_multicast_loop_v4(true)?;
socket.set_multicast_ttl_v4(255)?;
// TODO: correct interfaces?
socket.join_multicast_v4(From::from([224, 0, 0, 251]), Ipv4Addr::UNSPECIFIED)?;
Ok(MdnsService {
Ok($service_name {
socket,
// Given that we pass an IP address to bind, which does not need to be resolved, we can
// use std::net::UdpSocket::bind, instead of its async counterpart from async-std.
query_socket: std::net::UdpSocket::bind((Ipv4Addr::from([0u8, 0, 0, 0]), 0u16))?.into(),
query_socket,
query_interval: Interval::new_at(Instant::now(), Duration::from_secs(20)),
silent,
recv_buffer: [0; 2048],
@ -266,14 +278,24 @@ impl MdnsService {
}
}
impl fmt::Debug for MdnsService {
impl fmt::Debug for $service_name {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("MdnsService")
fmt.debug_struct("$service_name")
.field("silent", &self.silent)
.finish()
}
}
};
}
#[cfg(feature = "async-std")]
codegen!("async-std", MdnsService, async_std::net::UdpSocket, (|socket| Ok::<_, std::io::Error>(async_std::net::UdpSocket::from(socket))));
#[cfg(feature = "tokio")]
codegen!("tokio", TokioMdnsService, tokio::net::UdpSocket, (|socket| tokio::net::UdpSocket::from_std(socket)));
/// A valid mDNS packet received by the service.
#[derive(Debug)]
pub enum MdnsPacket {
@ -556,97 +578,121 @@ impl fmt::Debug for MdnsPeer {
#[cfg(test)]
mod tests {
use futures::executor::block_on;
use libp2p_core::{PeerId, multiaddr::multihash};
use std::{io::{Error, ErrorKind}, time::Duration};
use wasm_timer::ext::TryFutureExt;
use crate::service::{MdnsPacket, MdnsService};
macro_rules! testgen {
($runtime_name:ident, $service_name:ty, $block_on_fn:tt) => {
mod $runtime_name {
use libp2p_core::{PeerId, multiaddr::multihash};
use std::time::Duration;
use crate::service::MdnsPacket;
fn discover(peer_id: PeerId) {
block_on(async {
let mut service = MdnsService::new().unwrap();
loop {
let next = service.next().await;
service = next.0;
fn discover(peer_id: PeerId) {
let fut = async {
let mut service = <$service_name>::new().unwrap();
match next.1 {
MdnsPacket::Query(query) => {
let resp = crate::dns::build_query_response(
query.query_id(),
peer_id.clone(),
vec![].into_iter(),
Duration::from_secs(120),
).unwrap();
service.enqueue_response(resp);
loop {
let next = service.next().await;
service = next.0;
match next.1 {
MdnsPacket::Query(query) => {
let resp = crate::dns::build_query_response(
query.query_id(),
peer_id.clone(),
vec![].into_iter(),
Duration::from_secs(120),
).unwrap();
service.enqueue_response(resp);
}
MdnsPacket::Response(response) => {
for peer in response.discovered_peers() {
if peer.id() == &peer_id {
return;
}
}
}
MdnsPacket::ServiceDiscovery(_) => panic!(
"did not expect a service discovery packet",
)
}
MdnsPacket::Response(response) => {
for peer in response.discovered_peers() {
if peer.id() == &peer_id {
}
};
$block_on_fn(Box::pin(fut));
}
// As of today the underlying UDP socket is not stubbed out. Thus tests run in parallel to
// this unit tests inter fear with it. Test needs to be run in sequence to ensure test
// properties.
#[test]
fn respect_query_interval() {
let own_ips: Vec<std::net::IpAddr> = get_if_addrs::get_if_addrs().unwrap()
.into_iter()
.map(|i| i.addr.ip())
.collect();
let fut = async {
let mut service = <$service_name>::new().unwrap();
let mut sent_queries = vec![];
loop {
let next = service.next().await;
service = next.0;
match next.1 {
MdnsPacket::Query(query) => {
// Ignore queries from other nodes.
let source_ip = query.remote_addr().ip();
if !own_ips.contains(&source_ip) {
continue;
}
sent_queries.push(query);
if sent_queries.len() > 1 {
return;
}
}
// Ignore response packets. We don't stub out the UDP socket, thus this is
// either random noise from the network, or noise from other unit tests
// running in parallel.
MdnsPacket::Response(_) => {},
MdnsPacket::ServiceDiscovery(_) => {
panic!("Did not expect a service discovery packet.");
},
}
MdnsPacket::ServiceDiscovery(_) => panic!("did not expect a service discovery packet")
}
}
})
};
$block_on_fn(Box::pin(fut));
}
#[test]
fn discover_normal_peer_id() {
discover(PeerId::random())
}
#[test]
fn discover_long_peer_id() {
let max_value = String::from_utf8(vec![b'f'; 42]).unwrap();
let hash = multihash::Identity::digest(max_value.as_ref());
discover(PeerId::from_multihash(hash).unwrap())
}
}
}
}
// As of today the underlying UDP socket is not stubbed out. Thus tests run in parallel to this
// unit tests inter fear with it. Test needs to be run in sequence to ensure test properties.
#[test]
fn respect_query_interval() {
let own_ips: Vec<std::net::IpAddr> = get_if_addrs::get_if_addrs().unwrap()
.into_iter()
.map(|i| i.addr.ip())
.collect();
#[cfg(feature = "async-std")]
testgen!(
async_std,
crate::service::MdnsService,
(|fut| async_std::task::block_on::<_, ()>(fut))
);
let fut = async {
let mut service = MdnsService::new().unwrap();
let mut sent_queries = vec![];
loop {
let next = service.next().await;
service = next.0;
match next.1 {
MdnsPacket::Query(query) => {
// Ignore queries from other nodes.
let source_ip = query.remote_addr().ip();
if !own_ips.contains(&source_ip) {
continue;
}
sent_queries.push(query);
if sent_queries.len() > 1 {
return Ok(())
}
}
// Ignore response packets. We don't stub out the UDP socket, thus this is
// either random noise from the network, or noise from other unit tests running
// in parallel.
MdnsPacket::Response(_) => {},
MdnsPacket::ServiceDiscovery(_) => {
return Err(Error::new(ErrorKind::Other, "did not expect a service discovery packet"));
},
}
}
};
// TODO: This might be too long for a unit test.
block_on(fut.timeout(Duration::from_secs(41))).unwrap();
}
#[test]
fn discover_normal_peer_id() {
discover(PeerId::random())
}
#[test]
fn discover_long_peer_id() {
let max_value = String::from_utf8(vec![b'f'; 42]).unwrap();
let hash = multihash::Identity::digest(max_value.as_ref());
discover(PeerId::from_multihash(hash).unwrap())
}
#[cfg(feature = "tokio")]
testgen!(
tokio,
crate::service::TokioMdnsService,
(|fut| tokio::runtime::Runtime::new().unwrap().block_on::<futures::future::BoxFuture<()>>(fut))
);
}