protocols/mdns: Support multiple network interfaces (#2383)

Handling multiple interfaces in mdns. The socket logic was moved into an
instance while the mdns behaviour watches for interface changes and creates new
instances with a dedicated send/recv socket.

Co-authored-by: Max Inden <mail@max-inden.de>
This commit is contained in:
David Craven
2021-12-29 19:02:20 +01:00
committed by GitHub
parent 23f6b00b66
commit df2e5a591e
9 changed files with 404 additions and 352 deletions

View File

@ -18,79 +18,35 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use crate::dns::{build_query, build_query_response, build_service_discovery_response};
use crate::query::MdnsPacket;
use crate::IPV4_MDNS_MULTICAST_ADDRESS;
use async_io::{Async, Timer};
mod iface;
use self::iface::InterfaceState;
use crate::MdnsConfig;
use async_io::Timer;
use futures::prelude::*;
use if_watch::{IfEvent, IfWatcher};
use libp2p_core::connection::ListenerId;
use libp2p_core::{address_translation, multiaddr::Protocol, Multiaddr, PeerId};
use libp2p_core::{Multiaddr, PeerId};
use libp2p_swarm::{
protocols_handler::DummyProtocolsHandler, NetworkBehaviour, NetworkBehaviourAction,
PollParameters, ProtocolsHandler,
};
use smallvec::SmallVec;
use socket2::{Domain, Socket, Type};
use std::{
cmp,
collections::VecDeque,
fmt, io, iter,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket},
pin::Pin,
task::Context,
task::Poll,
time::{Duration, Instant},
};
/// Configuration for mDNS.
#[derive(Clone, Debug)]
pub struct MdnsConfig {
/// TTL to use for mdns records.
pub ttl: Duration,
/// Interval at which to poll the network for new peers. This isn't
/// necessary during normal operation but avoids the case that an
/// initial packet was lost and not discovering any peers until a new
/// peer joins the network. Receiving an mdns packet resets the timer
/// preventing unnecessary traffic.
pub query_interval: Duration,
/// IP address for multicast.
pub multicast_addr: IpAddr,
}
impl Default for MdnsConfig {
fn default() -> Self {
Self {
ttl: Duration::from_secs(6 * 60),
query_interval: Duration::from_secs(5 * 60),
multicast_addr: *IPV4_MDNS_MULTICAST_ADDRESS,
}
}
}
use std::collections::hash_map::{Entry, HashMap};
use std::{cmp, fmt, io, net::IpAddr, pin::Pin, task::Context, task::Poll, time::Instant};
/// A `NetworkBehaviour` for mDNS. Automatically discovers peers on the local network and adds
/// them to the topology.
#[derive(Debug)]
pub struct Mdns {
/// Main socket for listening.
recv_socket: Async<UdpSocket>,
/// Query socket for making queries.
send_socket: Async<UdpSocket>,
/// InterfaceState config.
config: MdnsConfig,
/// Iface watcher.
if_watch: IfWatcher,
/// Buffer used for receiving data from the main socket.
/// RFC6762 discourages packets larger than the interface MTU, but allows sizes of up to 9000
/// bytes, if it can be ensured that all participating devices can handle such large packets.
/// For computers with several interfaces and IP addresses responses can easily reach sizes in
/// the range of 3000 bytes, so 4096 seems sensible for now. For more information see
/// [rfc6762](https://tools.ietf.org/html/rfc6762#page-46).
recv_buffer: [u8; 4096],
/// Buffers pending to send on the main socket.
send_buffer: VecDeque<Vec<u8>>,
/// Mdns interface states.
iface_states: HashMap<IpAddr, InterfaceState>,
/// List of nodes that we have discovered, the address, and when their TTL expires.
///
@ -102,77 +58,18 @@ pub struct Mdns {
///
/// `None` if `discovered_nodes` is empty.
closest_expiration: Option<Timer>,
/// Queued events.
events: VecDeque<MdnsEvent>,
/// Discovery interval.
query_interval: Duration,
/// Record ttl.
ttl: Duration,
/// Discovery timer.
timeout: Timer,
// Multicast address.
multicast_addr: IpAddr,
}
impl Mdns {
/// Builds a new `Mdns` behaviour.
pub async fn new(config: MdnsConfig) -> io::Result<Self> {
let recv_socket = match config.multicast_addr {
IpAddr::V4(_) => {
let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(socket2::Protocol::UDP))?;
socket.set_reuse_address(true)?;
#[cfg(unix)]
socket.set_reuse_port(true)?;
socket.bind(&SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 5353).into())?;
socket.set_multicast_loop_v4(true)?;
socket.set_multicast_ttl_v4(255)?;
Async::new(UdpSocket::from(socket))?
}
IpAddr::V6(_) => {
let socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(socket2::Protocol::UDP))?;
socket.set_reuse_address(true)?;
#[cfg(unix)]
socket.set_reuse_port(true)?;
socket.bind(&SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 5353).into())?;
socket.set_multicast_loop_v6(true)?;
Async::new(UdpSocket::from(socket))?
}
};
let send_socket = {
let addr = match config.multicast_addr {
IpAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
IpAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
};
let socket = std::net::UdpSocket::bind(addr)?;
Async::new(socket)?
};
let if_watch = if_watch::IfWatcher::new().await?;
// randomize timer to prevent all converging and firing at the same time.
let query_interval = {
use rand::Rng;
let mut rng = rand::thread_rng();
let jitter = rng.gen_range(0..100);
config.query_interval + Duration::from_millis(jitter)
};
Ok(Self {
recv_socket,
send_socket,
config,
if_watch,
recv_buffer: [0; 4096],
send_buffer: Default::default(),
discovered_nodes: SmallVec::new(),
closest_expiration: None,
events: Default::default(),
query_interval,
ttl: config.ttl,
timeout: Timer::interval(query_interval),
multicast_addr: config.multicast_addr,
iface_states: Default::default(),
discovered_nodes: Default::default(),
closest_expiration: Default::default(),
})
}
@ -186,85 +83,15 @@ impl Mdns {
self.discovered_nodes.iter().map(|(p, _, _)| p)
}
fn reset_timer(&mut self) {
self.timeout.set_interval(self.query_interval);
}
fn fire_timer(&mut self) {
self.timeout
.set_interval_at(Instant::now(), self.query_interval);
}
fn inject_mdns_packet(&mut self, packet: MdnsPacket, params: &impl PollParameters) {
match packet {
MdnsPacket::Query(query) => {
self.reset_timer();
log::trace!("sending response");
for packet in build_query_response(
query.query_id(),
*params.local_peer_id(),
params.listened_addresses(),
self.ttl,
) {
self.send_buffer.push_back(packet);
}
}
MdnsPacket::Response(response) => {
// We replace the IP address with the address we observe the
// remote as and the address they listen on.
let obs_ip = Protocol::from(response.remote_addr().ip());
let obs_port = Protocol::Udp(response.remote_addr().port());
let observed: Multiaddr = iter::once(obs_ip).chain(iter::once(obs_port)).collect();
let mut discovered: SmallVec<[_; 4]> = SmallVec::new();
for peer in response.discovered_peers() {
if peer.id() == params.local_peer_id() {
continue;
}
let new_expiration = Instant::now() + peer.ttl();
let mut addrs: Vec<Multiaddr> = Vec::new();
for addr in peer.addresses() {
if let Some(new_addr) = address_translation(addr, &observed) {
addrs.push(new_addr.clone())
}
addrs.push(addr.clone())
}
for addr in addrs {
if let Some((_, _, cur_expires)) = self
.discovered_nodes
.iter_mut()
.find(|(p, a, _)| p == peer.id() && *a == addr)
{
*cur_expires = cmp::max(*cur_expires, new_expiration);
} else {
self.discovered_nodes
.push((*peer.id(), addr.clone(), new_expiration));
discovered.push((*peer.id(), addr));
}
}
}
self.closest_expiration = self
.discovered_nodes
.iter()
.fold(None, |exp, &(_, _, elem_exp)| {
Some(exp.map(|exp| cmp::min(exp, elem_exp)).unwrap_or(elem_exp))
})
.map(Timer::at);
self.events
.push_back(MdnsEvent::Discovered(DiscoveredAddrsIter {
inner: discovered.into_iter(),
}));
}
MdnsPacket::ServiceDiscovery(disc) => {
let resp = build_service_discovery_response(disc.query_id(), self.ttl);
self.send_buffer.push_back(resp);
/// Expires a node before the ttl.
pub fn expire_node(&mut self, peer_id: &PeerId) {
let now = Instant::now();
for (peer, _addr, expires) in &mut self.discovered_nodes {
if peer == peer_id {
*expires = now;
}
}
self.closest_expiration = Some(Timer::at(now));
}
}
@ -277,10 +104,9 @@ impl NetworkBehaviour for Mdns {
}
fn addresses_of_peer(&mut self, peer_id: &PeerId) -> Vec<Multiaddr> {
let now = Instant::now();
self.discovered_nodes
.iter()
.filter(move |(p, _, expires)| p == peer_id && *expires > now)
.filter(|(peer, _, _)| peer == peer_id)
.map(|(_, addr, _)| addr.clone())
.collect()
}
@ -295,7 +121,14 @@ impl NetworkBehaviour for Mdns {
}
fn inject_new_listen_addr(&mut self, _id: ListenerId, _addr: &Multiaddr) {
self.fire_timer();
log::trace!("waking interface state because listening address changed");
for (_, iface) in &mut self.iface_states {
iface.fire_timer();
}
}
fn inject_disconnected(&mut self, peer: &PeerId) {
self.expire_node(peer);
}
fn poll(
@ -303,123 +136,83 @@ impl NetworkBehaviour for Mdns {
cx: &mut Context<'_>,
params: &mut impl PollParameters,
) -> Poll<NetworkBehaviourAction<Self::OutEvent, DummyProtocolsHandler>> {
// Poll ifwatch.
while let Poll::Ready(event) = Pin::new(&mut self.if_watch).poll(cx) {
let socket = self.recv_socket.get_ref();
match event {
Ok(IfEvent::Up(inet)) => {
if inet.addr().is_loopback() {
let addr = inet.addr();
if addr.is_loopback() {
continue;
}
match self.multicast_addr {
IpAddr::V4(multicast) => {
if let IpAddr::V4(addr) = inet.addr() {
log::trace!("joining multicast on iface {}", addr);
if let Err(err) = socket.join_multicast_v4(&multicast, &addr) {
log::error!("join multicast failed: {}", err);
} else {
self.fire_timer();
}
}
}
IpAddr::V6(multicast) => {
if let IpAddr::V6(addr) = inet.addr() {
log::trace!("joining multicast on iface {}", addr);
if let Err(err) = socket.join_multicast_v6(&multicast, 0) {
log::error!("join multicast failed: {}", err);
} else {
self.fire_timer();
}
if addr.is_ipv4() && self.config.enable_ipv6
|| addr.is_ipv6() && !self.config.enable_ipv6
{
continue;
}
if let Entry::Vacant(e) = self.iface_states.entry(addr) {
match InterfaceState::new(addr, self.config.clone()) {
Ok(iface_state) => {
e.insert(iface_state);
}
Err(err) => log::error!("failed to create `InterfaceState`: {}", err),
}
}
}
Ok(IfEvent::Down(inet)) => {
if inet.addr().is_loopback() {
continue;
}
match self.multicast_addr {
IpAddr::V4(multicast) => {
if let IpAddr::V4(addr) = inet.addr() {
log::trace!("leaving multicast on iface {}", addr);
if let Err(err) = socket.leave_multicast_v4(&multicast, &addr) {
log::error!("leave multicast failed: {}", err);
}
}
}
IpAddr::V6(multicast) => {
if let IpAddr::V6(addr) = inet.addr() {
log::trace!("leaving multicast on iface {}", addr);
if let Err(err) = socket.leave_multicast_v6(&multicast, 0) {
log::error!("leave multicast failed: {}", err);
}
}
}
if self.iface_states.contains_key(&inet.addr()) {
log::info!("dropping instance {}", inet.addr());
self.iface_states.remove(&inet.addr());
}
}
Err(err) => log::error!("if watch returned an error: {}", err),
}
}
// Poll receive socket.
while self.recv_socket.poll_readable(cx).is_ready() {
match self
.recv_socket
.recv_from(&mut self.recv_buffer)
.now_or_never()
{
Some(Ok((len, from))) => {
if let Some(packet) = MdnsPacket::new_from_bytes(&self.recv_buffer[..len], from)
{
self.inject_mdns_packet(packet, params);
}
}
Some(Err(err)) => log::error!("Failed reading datagram: {}", err),
_ => {}
}
}
// Send responses.
while self.send_socket.poll_writable(cx).is_ready() {
if let Some(packet) = self.send_buffer.pop_front() {
match self
.send_socket
.send_to(&packet, SocketAddr::new(self.multicast_addr, 5353))
.now_or_never()
{
Some(Ok(_)) => {}
Some(Err(err)) => log::error!("{}", err),
None => self.send_buffer.push_front(packet),
}
} else if Pin::new(&mut self.timeout).poll_next(cx).is_ready() {
log::trace!("sending query");
self.send_buffer.push_back(build_query());
} else {
break;
}
}
// Emit discovered event.
if let Some(event) = self.events.pop_front() {
let mut discovered = SmallVec::<[(PeerId, Multiaddr); 4]>::new();
for (_, iface_state) in &mut self.iface_states {
while let Some((peer, addr, expiration)) = iface_state.poll(cx, params) {
if let Some((_, _, cur_expires)) = self
.discovered_nodes
.iter_mut()
.find(|(p, a, _)| *p == peer && *a == addr)
{
*cur_expires = cmp::max(*cur_expires, expiration);
} else {
log::info!("discovered: {} {}", peer, addr);
self.discovered_nodes.push((peer, addr.clone(), expiration));
discovered.push((peer, addr));
}
}
}
if !discovered.is_empty() {
let event = MdnsEvent::Discovered(DiscoveredAddrsIter {
inner: discovered.into_iter(),
});
return Poll::Ready(NetworkBehaviourAction::GenerateEvent(event));
}
// Emit expired event.
if let Some(ref mut closest_expiration) = self.closest_expiration {
if let Poll::Ready(now) = Pin::new(closest_expiration).poll(cx) {
let mut expired = SmallVec::<[(PeerId, Multiaddr); 4]>::new();
while let Some(pos) = self
.discovered_nodes
.iter()
.position(|(_, _, exp)| *exp <= now)
{
let (peer_id, addr, _) = self.discovered_nodes.remove(pos);
expired.push((peer_id, addr));
}
if !expired.is_empty() {
let event = MdnsEvent::Expired(ExpiredAddrsIter {
inner: expired.into_iter(),
});
return Poll::Ready(NetworkBehaviourAction::GenerateEvent(event));
}
let now = Instant::now();
let mut closest_expiration = None;
let mut expired = SmallVec::<[(PeerId, Multiaddr); 4]>::new();
self.discovered_nodes.retain(|(peer, addr, expiration)| {
if *expiration <= now {
log::info!("expired: {} {}", peer, addr);
expired.push((*peer, addr.clone()));
return false;
}
closest_expiration = Some(closest_expiration.unwrap_or(*expiration).min(*expiration));
true
});
if !expired.is_empty() {
let event = MdnsEvent::Expired(ExpiredAddrsIter {
inner: expired.into_iter(),
});
return Poll::Ready(NetworkBehaviourAction::GenerateEvent(event));
}
if let Some(closest_expiration) = closest_expiration {
let mut timer = Timer::at(closest_expiration);
let _ = Pin::new(&mut timer).poll(cx);
self.closest_expiration = Some(timer);
}
Poll::Pending
}