// Copyright 2020 Parity Technologies (UK) Ltd. // // Permission is hereby granted, free of charge, to any person obtaining a // copy of this software and associated documentation files (the "Software"), // to deal in the Software without restriction, including without limitation // the rights to use, copy, modify, merge, publish, distribute, sublicense, // and/or sell copies of the Software, and to permit persons to whom the // Software is furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS // OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. //! A [`ProtocolsHandler`] implementation that combines multiple other `ProtocolsHandler`s //! indexed by some key. use crate::NegotiatedSubstream; use crate::protocols_handler::{ KeepAlive, IntoProtocolsHandler, ProtocolsHandler, ProtocolsHandlerEvent, ProtocolsHandlerUpgrErr, SubstreamProtocol }; use crate::upgrade::{ InboundUpgradeSend, OutboundUpgradeSend, UpgradeInfoSend }; use futures::{future::BoxFuture, prelude::*}; use libp2p_core::{ConnectedPoint, PeerId, upgrade::ProtocolName}; use rand::Rng; use std::{ collections::{HashMap, HashSet}, error, fmt, hash::Hash, iter::{self, FromIterator}, task::{Context, Poll} }; /// A [`ProtocolsHandler`] for multiple other `ProtocolsHandler`s. #[derive(Clone)] pub struct MultiHandler { handlers: HashMap } impl fmt::Debug for MultiHandler where K: fmt::Debug + Eq + Hash, H: fmt::Debug { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("MultiHandler") .field("handlers", &self.handlers) .finish() } } impl MultiHandler where K: Hash + Eq, H: ProtocolsHandler { /// Create and populate a `MultiHandler` from the given handler iterator. /// /// It is an error for any two protocols handlers to share the same protocol name. pub fn try_from_iter(iter: I) -> Result where I: IntoIterator { let m = MultiHandler { handlers: HashMap::from_iter(iter) }; uniq_proto_names(m.handlers.values().map(|h| h.listen_protocol().into_upgrade().1))?; Ok(m) } } impl ProtocolsHandler for MultiHandler where K: Clone + Hash + Eq + Send + 'static, H: ProtocolsHandler, H::InboundProtocol: InboundUpgradeSend, H::OutboundProtocol: OutboundUpgradeSend { type InEvent = (K, ::InEvent); type OutEvent = (K, ::OutEvent); type Error = ::Error; type InboundProtocol = Upgrade::InboundProtocol>; type OutboundProtocol = ::OutboundProtocol; type OutboundOpenInfo = (K, ::OutboundOpenInfo); fn listen_protocol(&self) -> SubstreamProtocol { let upgrades = self.handlers.iter() .map(|(k, h)| (k.clone(), h.listen_protocol().into_upgrade().1)) .collect(); SubstreamProtocol::new(Upgrade { upgrades }) } fn inject_fully_negotiated_outbound ( &mut self, protocol: ::Output, (key, arg): Self::OutboundOpenInfo ) { if let Some(h) = self.handlers.get_mut(&key) { h.inject_fully_negotiated_outbound(protocol, arg) } else { log::error!("inject_fully_negotiated_outbound: no handler for key") } } fn inject_fully_negotiated_inbound ( &mut self, (key, arg): ::Output ) { if let Some(h) = self.handlers.get_mut(&key) { h.inject_fully_negotiated_inbound(arg) } else { log::error!("inject_fully_negotiated_inbound: no handler for key") } } fn inject_event(&mut self, (key, event): Self::InEvent) { if let Some(h) = self.handlers.get_mut(&key) { h.inject_event(event) } else { log::error!("inject_event: no handler for key") } } fn inject_dial_upgrade_error ( &mut self, (key, arg): Self::OutboundOpenInfo, error: ProtocolsHandlerUpgrErr<::Error> ) { if let Some(h) = self.handlers.get_mut(&key) { h.inject_dial_upgrade_error(arg, error) } else { log::error!("inject_dial_upgrade_error: no handler for protocol") } } fn connection_keep_alive(&self) -> KeepAlive { self.handlers.values() .map(|h| h.connection_keep_alive()) .max() .unwrap_or(KeepAlive::No) } fn poll(&mut self, cx: &mut Context) -> Poll> { // Not always polling handlers in the same order should give anyone the chance to make progress. let pos = rand::thread_rng().gen_range(0, self.handlers.len()); for (k, h) in self.handlers.iter_mut().skip(pos) { if let Poll::Ready(e) = h.poll(cx) { let e = e.map_outbound_open_info(|i| (k.clone(), i)).map_custom(|p| (k.clone(), p)); return Poll::Ready(e) } } for (k, h) in self.handlers.iter_mut().take(pos) { if let Poll::Ready(e) = h.poll(cx) { let e = e.map_outbound_open_info(|i| (k.clone(), i)).map_custom(|p| (k.clone(), p)); return Poll::Ready(e) } } Poll::Pending } } /// A [`IntoProtocolsHandler`] for multiple other `IntoProtocolsHandler`s. #[derive(Clone)] pub struct IntoMultiHandler { handlers: HashMap } impl fmt::Debug for IntoMultiHandler where K: fmt::Debug + Eq + Hash, H: fmt::Debug { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("IntoMultiHandler") .field("handlers", &self.handlers) .finish() } } impl IntoMultiHandler where K: Hash + Eq, H: IntoProtocolsHandler { /// Create and populate an `IntoMultiHandler` from the given iterator. /// /// It is an error for any two protocols handlers to share the same protocol name. pub fn try_from_iter(iter: I) -> Result where I: IntoIterator { let m = IntoMultiHandler { handlers: HashMap::from_iter(iter) }; uniq_proto_names(m.handlers.values().map(|h| h.inbound_protocol()))?; Ok(m) } } impl IntoProtocolsHandler for IntoMultiHandler where K: Clone + Eq + Hash + Send + 'static, H: IntoProtocolsHandler { type Handler = MultiHandler; fn into_handler(self, p: &PeerId, c: &ConnectedPoint) -> Self::Handler { MultiHandler { handlers: self.handlers.into_iter() .map(|(k, h)| (k, h.into_handler(p, c))) .collect() } } fn inbound_protocol(&self) -> ::InboundProtocol { Upgrade { upgrades: self.handlers.iter() .map(|(k, h)| (k.clone(), h.inbound_protocol())) .collect() } } } /// Index and protocol name pair used as `UpgradeInfo::Info`. #[derive(Debug, Clone)] pub struct IndexedProtoName(usize, H); impl ProtocolName for IndexedProtoName { fn protocol_name(&self) -> &[u8] { self.1.protocol_name() } } /// Inbound and outbound upgrade for all `ProtocolsHandler`s. #[derive(Clone)] pub struct Upgrade { upgrades: Vec<(K, H)> } impl fmt::Debug for Upgrade where K: fmt::Debug + Eq + Hash, H: fmt::Debug { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("Upgrade") .field("upgrades", &self.upgrades) .finish() } } impl UpgradeInfoSend for Upgrade where H: UpgradeInfoSend, K: Send + 'static { type Info = IndexedProtoName; type InfoIter = std::vec::IntoIter; fn protocol_info(&self) -> Self::InfoIter { self.upgrades.iter().enumerate() .map(|(i, (_, h))| iter::repeat(i).zip(h.protocol_info())) .flatten() .map(|(i, h)| IndexedProtoName(i, h)) .collect::>() .into_iter() } } impl InboundUpgradeSend for Upgrade where H: InboundUpgradeSend, K: Send + 'static { type Output = (K, ::Output); type Error = (K, ::Error); type Future = BoxFuture<'static, Result>; fn upgrade_inbound(mut self, resource: NegotiatedSubstream, info: Self::Info) -> Self::Future { let IndexedProtoName(index, info) = info; let (key, upgrade) = self.upgrades.remove(index); upgrade.upgrade_inbound(resource, info) .map(move |out| { match out { Ok(o) => Ok((key, o)), Err(e) => Err((key, e)) } }) .boxed() } } impl OutboundUpgradeSend for Upgrade where H: OutboundUpgradeSend, K: Send + 'static { type Output = (K, ::Output); type Error = (K, ::Error); type Future = BoxFuture<'static, Result>; fn upgrade_outbound(mut self, resource: NegotiatedSubstream, info: Self::Info) -> Self::Future { let IndexedProtoName(index, info) = info; let (key, upgrade) = self.upgrades.remove(index); upgrade.upgrade_outbound(resource, info) .map(move |out| { match out { Ok(o) => Ok((key, o)), Err(e) => Err((key, e)) } }) .boxed() } } /// Check that no two protocol names are equal. fn uniq_proto_names(iter: I) -> Result<(), DuplicateProtonameError> where I: Iterator, T: UpgradeInfoSend { let mut set = HashSet::new(); for infos in iter { for i in infos.protocol_info() { let v = Vec::from(i.protocol_name()); if set.contains(&v) { return Err(DuplicateProtonameError(v)) } else { set.insert(v); } } } Ok(()) } /// It is an error if two handlers share the same protocol name. #[derive(Debug, Clone)] pub struct DuplicateProtonameError(Vec); impl DuplicateProtonameError { /// The protocol name bytes that occured in more than one handler. pub fn protocol_name(&self) -> &[u8] { &self.0 } } impl fmt::Display for DuplicateProtonameError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { if let Ok(s) = std::str::from_utf8(&self.0) { write!(f, "duplicate protocol name: {}", s) } else { write!(f, "duplicate protocol name: {:?}", self.0) } } } impl error::Error for DuplicateProtonameError {}