Rewrite the MemoryTransport to be similar to the TcpConfig (#951)

* Rewrite the MemoryTransport to be similar to the TcpConfig

* Add small test

* Test and bug fixes
This commit is contained in:
Pierre Krieger
2019-02-18 17:05:50 +01:00
committed by GitHub
parent ca9534a38e
commit 43e4d1f589
7 changed files with 197 additions and 116 deletions

View File

@ -13,6 +13,7 @@ categories = ["network-programming", "asynchronous"]
bs58 = "0.2.0"
bytes = "0.4"
fnv = "1.0"
lazy_static = "1.2"
log = "0.4"
multiaddr = { package = "parity-multiaddr", version = "0.1.0", path = "../misc/multiaddr" }
multihash = { package = "parity-multihash", version = "0.1.0", path = "../misc/multihash" }
@ -21,6 +22,7 @@ futures = { version = "0.1", features = ["use_std"] }
parking_lot = "0.7"
protobuf = "2.3"
quick-error = "1.2"
rand = "0.6"
rw-stream-sink = { version = "0.1.0", path = "../misc/rw-stream-sink" }
smallvec = "0.6"
tokio-executor = "0.1.4"

View File

@ -311,12 +311,12 @@ mod tests {
#[test]
fn incoming_event() {
let (tx, rx) = transport::connector();
let mem_transport = transport::MemoryTransport::default();
let mut listeners = ListenersStream::new(rx);
listeners.listen_on("/memory".parse().unwrap()).unwrap();
let mut listeners = ListenersStream::new(mem_transport);
let actual_addr = listeners.listen_on("/memory/0".parse().unwrap()).unwrap();
let dial = tx.dial("/memory".parse().unwrap()).unwrap();
let dial = mem_transport.dial(actual_addr.clone()).unwrap();
let future = listeners
.into_future()
@ -324,8 +324,8 @@ mod tests {
.and_then(|(event, _)| {
match event {
Some(ListenersEvent::Incoming { listen_addr, upgrade, send_back_addr }) => {
assert_eq!(listen_addr, "/memory".parse().unwrap());
assert_eq!(send_back_addr, "/memory".parse().unwrap());
assert_eq!(listen_addr, actual_addr);
assert_eq!(send_back_addr, actual_addr);
upgrade.map(|_| ()).map_err(|_| panic!())
},
_ => panic!()

View File

@ -64,7 +64,7 @@ fn nat_traversal_transforms_the_observed_address_according_to_the_transport_used
let mut raw_swarm = RawSwarm::<_, _, _, Handler, _>::new(transport, PeerId::random());
let addr1 = "/ip4/127.0.0.1/tcp/1234".parse::<Multiaddr>().expect("bad multiaddr");
// An unrelated outside address is returned as-is, no transform
let outside_addr1 = "/memory".parse::<Multiaddr>().expect("bad multiaddr");
let outside_addr1 = "/memory/0".parse::<Multiaddr>().expect("bad multiaddr");
let addr2 = "/ip4/127.0.0.2/tcp/1234".parse::<Multiaddr>().expect("bad multiaddr");
let outside_addr2 = "/ip4/127.0.0.2/tcp/1234".parse::<Multiaddr>().expect("bad multiaddr");
@ -128,7 +128,7 @@ fn num_incoming_negotiated() {
transport.set_initial_listener_state(ListenerState::Ok(Async::Ready(Some((peer_id, muxer)))));
let mut swarm = RawSwarm::<_, _, _, Handler, _>::new(transport, PeerId::random());
swarm.listen_on("/memory".parse().unwrap()).unwrap();
swarm.listen_on("/memory/0".parse().unwrap()).unwrap();
// no incoming yet
assert_eq!(swarm.incoming_negotiated().count(), 0);
@ -203,7 +203,7 @@ fn querying_for_pending_peer() {
let peer_id = PeerId::random();
let peer = swarm.peer(peer_id.clone());
assert_matches!(peer, Peer::NotConnected(PeerNotConnected{ .. }));
let addr = "/memory".parse().expect("bad multiaddr");
let addr = "/memory/0".parse().expect("bad multiaddr");
let pending_peer = peer.into_not_connected().unwrap().connect(addr, Handler::default());
assert_matches!(pending_peer, PeerPendingConnect { .. });
}
@ -255,7 +255,7 @@ fn poll_with_closed_listener() {
transport.set_initial_listener_state(ListenerState::Ok(Async::Ready(None)));
let mut swarm = RawSwarm::<_, _, _, Handler, _>::new(transport, PeerId::random());
swarm.listen_on("/memory".parse().unwrap()).unwrap();
swarm.listen_on("/memory/0".parse().unwrap()).unwrap();
let mut rt = Runtime::new().unwrap();
let swarm = Arc::new(Mutex::new(swarm));
@ -274,7 +274,7 @@ fn unknown_peer_that_is_unreachable_yields_unknown_peer_dial_error() {
let mut transport = DummyTransport::new();
transport.make_dial_fail();
let mut swarm = RawSwarm::<_, _, _, Handler, _>::new(transport, PeerId::random());
let addr = "/memory".parse::<Multiaddr>().expect("bad multiaddr");
let addr = "/memory/0".parse::<Multiaddr>().expect("bad multiaddr");
let handler = Handler::default();
let dial_result = swarm.dial(addr, handler);
assert!(dial_result.is_ok());
@ -311,7 +311,7 @@ fn known_peer_that_is_unreachable_yields_dial_error() {
let mut swarm1 = swarm1.lock();
let peer = swarm1.peer(peer_id.clone());
assert_matches!(peer, Peer::NotConnected(PeerNotConnected{ .. }));
let addr = "/memory".parse::<Multiaddr>().expect("bad multiaddr");
let addr = "/memory/0".parse::<Multiaddr>().expect("bad multiaddr");
let pending_peer = peer.into_not_connected().unwrap().connect(addr, Handler::default());
assert_matches!(pending_peer, PeerPendingConnect { .. });
}
@ -466,7 +466,7 @@ fn limit_incoming_connections() {
let mut swarm = RawSwarm::<_, _, _, Handler, _>::new_with_incoming_limit(
transport, PeerId::random(), Some(limit));
assert_eq!(swarm.incoming_limit(), Some(limit));
swarm.listen_on("/memory".parse().unwrap()).unwrap();
swarm.listen_on("/memory/0".parse().unwrap()).unwrap();
assert_eq!(swarm.incoming_negotiated().count(), 0);
let swarm = Arc::new(Mutex::new(swarm));

View File

@ -20,62 +20,98 @@
use crate::{Transport, transport::TransportError};
use bytes::{Bytes, IntoBuf};
use futures::{future::{self, FutureResult}, prelude::*, stream, sync::mpsc};
use fnv::FnvHashMap;
use futures::{future::{self, FutureResult}, prelude::*, sync::mpsc, try_ready};
use lazy_static::lazy_static;
use multiaddr::{Protocol, Multiaddr};
use parking_lot::Mutex;
use rw_stream_sink::RwStreamSink;
use std::{error, fmt, sync::Arc};
use std::{collections::hash_map::Entry, error, fmt, io, num::NonZeroU64};
/// Builds a new pair of `Transport`s. The dialer can reach the listener by dialing `/memory`.
#[inline]
pub fn connector() -> (Dialer, Listener) {
let (tx, rx) = mpsc::unbounded();
(Dialer(tx), Listener(Arc::new(Mutex::new(rx))))
lazy_static! {
static ref HUB: Mutex<FnvHashMap<NonZeroU64, mpsc::UnboundedSender<Channel<Bytes>>>> = Mutex::new(FnvHashMap::default());
}
/// Same as `connector()`, but allows customizing the type used for transmitting packets between
/// the two endpoints.
#[inline]
pub fn connector_custom_type<T>() -> (Dialer<T>, Listener<T>) {
let (tx, rx) = mpsc::unbounded();
(Dialer(tx), Listener(Arc::new(Mutex::new(rx))))
}
/// Transport that supports `/memory/N` multiaddresses.
#[derive(Debug, Copy, Clone, Default)]
pub struct MemoryTransport;
/// Dialing end of the memory transport.
pub struct Dialer<T = Bytes>(mpsc::UnboundedSender<Chan<T>>);
impl<T> Clone for Dialer<T> {
fn clone(&self) -> Self {
Dialer(self.0.clone())
}
}
impl<T: IntoBuf + Send + 'static> Transport for Dialer<T> {
type Output = Channel<T>;
impl Transport for MemoryTransport {
type Output = Channel<Bytes>;
type Error = MemoryTransportError;
type Listener = Box<dyn Stream<Item=(Self::ListenerUpgrade, Multiaddr), Error=MemoryTransportError> + Send>;
type ListenerUpgrade = FutureResult<Self::Output, MemoryTransportError>;
type Dial = Box<dyn Future<Item=Self::Output, Error=MemoryTransportError> + Send>;
type Listener = Listener;
type ListenerUpgrade = FutureResult<Self::Output, Self::Error>;
type Dial = FutureResult<Self::Output, Self::Error>;
fn listen_on(self, addr: Multiaddr) -> Result<(Self::Listener, Multiaddr), TransportError<Self::Error>> {
Err(TransportError::MultiaddrNotSupported(addr))
let port = if let Ok(port) = parse_memory_addr(&addr) {
port
} else {
return Err(TransportError::MultiaddrNotSupported(addr));
};
let mut hub = (&*HUB).lock();
let port = if let Some(port) = NonZeroU64::new(port) {
port
} else {
loop {
let port = match NonZeroU64::new(rand::random()) {
Some(p) => p,
None => continue,
};
if !hub.contains_key(&port) {
break port;
}
}
};
let actual_addr = Protocol::Memory(port.get()).into();
let (tx, rx) = mpsc::unbounded();
match hub.entry(port) {
Entry::Occupied(_) => return Err(TransportError::Other(MemoryTransportError::Unreachable)),
Entry::Vacant(e) => e.insert(tx),
};
let listener = Listener {
port,
receiver: rx,
};
Ok((listener, actual_addr))
}
fn dial(self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
if !is_memory_addr(&addr) {
return Err(TransportError::MultiaddrNotSupported(addr))
let port = if let Ok(port) = parse_memory_addr(&addr) {
if let Some(port) = NonZeroU64::new(port) {
port
} else {
return Err(TransportError::Other(MemoryTransportError::Unreachable));
}
} else {
return Err(TransportError::MultiaddrNotSupported(addr));
};
let hub = HUB.lock();
let chan = if let Some(tx) = hub.get(&port) {
let (a_tx, a_rx) = mpsc::unbounded();
let (b_tx, b_rx) = mpsc::unbounded();
let a = Chan { incoming: a_rx, outgoing: b_tx };
let b = Chan { incoming: b_rx, outgoing: a_tx };
let future = self.0.send(b)
.map(move |_| a.into())
.map_err(|_| MemoryTransportError::RemoteClosed);
Ok(Box::new(future))
let a = RwStreamSink::new(Chan { incoming: a_rx, outgoing: b_tx });
let b = RwStreamSink::new(Chan { incoming: b_rx, outgoing: a_tx });
if tx.unbounded_send(b).is_err() {
return Err(TransportError::Other(MemoryTransportError::Unreachable));
}
a
} else {
return Err(TransportError::Other(MemoryTransportError::Unreachable));
};
Ok(future::ok(chan))
}
fn nat_traversal(&self, server: &Multiaddr, observed: &Multiaddr) -> Option<Multiaddr> {
// TODO: NAT traversal for `/memory` addresses? how does that make sense?
if server == observed {
Some(server.clone())
} else {
@ -87,76 +123,69 @@ impl<T: IntoBuf + Send + 'static> Transport for Dialer<T> {
/// Error that can be produced from the `MemoryTransport`.
#[derive(Debug, Copy, Clone)]
pub enum MemoryTransportError {
/// The other side of the transport has been closed earlier.
RemoteClosed,
/// There's no listener on the given port.
Unreachable,
/// Tries to listen on a port that is already in use.
AlreadyInUse,
}
impl fmt::Display for MemoryTransportError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
MemoryTransportError::RemoteClosed =>
write!(f, "The other side of the memory transport has been closed."),
MemoryTransportError::Unreachable => write!(f, "No listener on the given port."),
MemoryTransportError::AlreadyInUse => write!(f, "Port already occupied."),
}
}
}
impl error::Error for MemoryTransportError {}
/// Receiving end of the memory transport.
pub struct Listener<T = Bytes>(Arc<Mutex<mpsc::UnboundedReceiver<Chan<T>>>>);
impl<T> Clone for Listener<T> {
fn clone(&self) -> Self {
Listener(self.0.clone())
}
/// Listener for memory connections.
pub struct Listener {
/// Port we're listening on.
port: NonZeroU64,
/// Receives incoming connections.
receiver: mpsc::UnboundedReceiver<Channel<Bytes>>,
}
impl<T: IntoBuf + Send + 'static> Transport for Listener<T> {
type Output = Channel<T>;
impl Stream for Listener {
type Item = (FutureResult<Channel<Bytes>, MemoryTransportError>, Multiaddr);
type Error = MemoryTransportError;
type Listener = Box<dyn Stream<Item=(Self::ListenerUpgrade, Multiaddr), Error=MemoryTransportError> + Send>;
type ListenerUpgrade = FutureResult<Self::Output, MemoryTransportError>;
type Dial = Box<dyn Future<Item=Self::Output, Error=MemoryTransportError> + Send>;
fn listen_on(self, addr: Multiaddr) -> Result<(Self::Listener, Multiaddr), TransportError<Self::Error>> {
if !is_memory_addr(&addr) {
return Err(TransportError::MultiaddrNotSupported(addr));
}
let addr2 = addr.clone();
let receiver = self.0.clone();
let stream = stream::poll_fn(move || receiver.lock().poll())
.map(move |channel| {
(future::ok(channel.into()), addr.clone())
})
.map_err(|()| unreachable!());
Ok((Box::new(stream), addr2))
}
#[inline]
fn dial(self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
Err(TransportError::MultiaddrNotSupported(addr))
}
#[inline]
fn nat_traversal(&self, server: &Multiaddr, observed: &Multiaddr) -> Option<Multiaddr> {
if server == observed {
Some(server.clone())
} else {
None
}
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
let channel = try_ready!(Ok(self.receiver.poll()
.expect("An unbounded receiver never panics; QED")));
let channel = match channel {
Some(c) => c,
None => return Ok(Async::Ready(None)),
};
let dialed_addr = Protocol::Memory(self.port.get()).into();
Ok(Async::Ready(Some((future::ok(channel), dialed_addr))))
}
}
/// Returns `true` if and only if the address is `/memory`.
fn is_memory_addr(a: &Multiaddr) -> bool {
impl Drop for Listener {
fn drop(&mut self) {
let val_in = HUB.lock().remove(&self.port);
debug_assert!(val_in.is_some());
}
}
/// If the address is `/memory/n`, returns the value of `n`.
fn parse_memory_addr(a: &Multiaddr) -> Result<u64, ()> {
let mut iter = a.iter();
if iter.next() != Some(Protocol::Memory) {
return false;
}
let port = if let Some(Protocol::Memory(port)) = iter.next() {
port
} else {
return Err(());
};
if iter.next().is_some() {
return false;
return Err(());
}
true
Ok(port)
}
/// A channel represents an established, in-memory, logical connection between two endpoints.
@ -174,31 +203,31 @@ pub struct Chan<T = Bytes> {
impl<T> Stream for Chan<T> {
type Item = T;
type Error = MemoryTransportError;
type Error = io::Error;
#[inline]
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
self.incoming.poll().map_err(|()| MemoryTransportError::RemoteClosed)
self.incoming.poll().map_err(|()| io::ErrorKind::BrokenPipe.into())
}
}
impl<T> Sink for Chan<T> {
type SinkItem = T;
type SinkError = MemoryTransportError;
type SinkError = io::Error;
#[inline]
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
self.outgoing.start_send(item).map_err(|_| MemoryTransportError::RemoteClosed)
self.outgoing.start_send(item).map_err(|_| io::ErrorKind::BrokenPipe.into())
}
#[inline]
fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
self.outgoing.poll_complete().map_err(|_| MemoryTransportError::RemoteClosed)
self.outgoing.poll_complete().map_err(|_| io::ErrorKind::BrokenPipe.into())
}
#[inline]
fn close(&mut self) -> Poll<(), Self::SinkError> {
self.outgoing.close().map_err(|_| MemoryTransportError::RemoteClosed)
self.outgoing.close().map_err(|_| io::ErrorKind::BrokenPipe.into())
}
}
@ -208,3 +237,41 @@ impl<T: IntoBuf> Into<RwStreamSink<Chan<T>>> for Chan<T> {
RwStreamSink::new(self)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_memory_addr_works() {
assert_eq!(parse_memory_addr(&"/memory/5".parse().unwrap()), Ok(5));
assert_eq!(parse_memory_addr(&"/tcp/150".parse().unwrap()), Err(()));
assert_eq!(parse_memory_addr(&"/memory/0".parse().unwrap()), Ok(0));
assert_eq!(parse_memory_addr(&"/memory/5/tcp/150".parse().unwrap()), Err(()));
assert_eq!(parse_memory_addr(&"/tcp/150/memory/5".parse().unwrap()), Err(()));
assert_eq!(parse_memory_addr(&"/memory/1234567890".parse().unwrap()), Ok(1_234_567_890));
}
#[test]
fn listening_twice() {
let transport = MemoryTransport::default();
assert!(transport.listen_on("/memory/5".parse().unwrap()).is_ok());
assert!(transport.listen_on("/memory/5".parse().unwrap()).is_ok());
let _listener = transport.listen_on("/memory/5".parse().unwrap()).unwrap();
assert!(transport.listen_on("/memory/5".parse().unwrap()).is_err());
assert!(transport.listen_on("/memory/5".parse().unwrap()).is_err());
drop(_listener);
assert!(transport.listen_on("/memory/5".parse().unwrap()).is_ok());
assert!(transport.listen_on("/memory/5".parse().unwrap()).is_ok());
}
#[test]
fn port_not_in_use() {
let transport = MemoryTransport::default();
assert!(transport.dial("/memory/5".parse().unwrap()).is_err());
let _listener = transport.listen_on("/memory/5".parse().unwrap()).unwrap();
assert!(transport.dial("/memory/5".parse().unwrap()).is_ok());
}
// TODO: test that is actually works
}

View File

@ -44,7 +44,7 @@ pub mod timeout;
pub mod upgrade;
pub use self::choice::OrTransport;
pub use self::memory::connector;
pub use self::memory::MemoryTransport;
pub use self::upgrade::Upgrade;
/// A transport is an object that can be used to produce connections by listening or dialing a

View File

@ -52,7 +52,8 @@ pub enum Protocol<'a> {
P2pWebRtcDirect,
P2pWebRtcStar,
P2pWebSocketStar,
Memory,
/// Contains the "port" to contact. Similar to TCP or UDP, 0 means "assign me a port".
Memory(u64),
Onion(Cow<'a, [u8; 10]>, u16),
P2p(Multihash),
P2pCircuit,
@ -139,7 +140,10 @@ impl<'a> Protocol<'a> {
"p2p-webrtc-star" => Ok(Protocol::P2pWebRtcStar),
"p2p-webrtc-direct" => Ok(Protocol::P2pWebRtcDirect),
"p2p-circuit" => Ok(Protocol::P2pCircuit),
"memory" => Ok(Protocol::Memory),
"memory" => {
let s = iter.next().ok_or(Error::InvalidProtocolString)?;
Ok(Protocol::Memory(s.parse()?))
}
_ => Err(Error::UnknownProtocolString)
}
}
@ -200,7 +204,12 @@ impl<'a> Protocol<'a> {
P2P_WEBRTC_DIRECT => Ok((Protocol::P2pWebRtcDirect, input)),
P2P_WEBRTC_STAR => Ok((Protocol::P2pWebRtcStar, input)),
P2P_WEBSOCKET_STAR => Ok((Protocol::P2pWebSocketStar, input)),
MEMORY => Ok((Protocol::Memory, input)),
MEMORY => {
let (data, rest) = split_at(8, input)?;
let mut rdr = Cursor::new(data);
let num = rdr.read_u64::<BigEndian>()?;
Ok((Protocol::Memory(num), rest))
}
ONION => {
let (data, rest) = split_at(12, input)?;
let port = BigEndian::read_u16(&data[10 ..]);
@ -315,7 +324,10 @@ impl<'a> Protocol<'a> {
Protocol::P2pWebRtcStar => w.write_all(encode::u32(P2P_WEBRTC_STAR, &mut buf))?,
Protocol::P2pWebRtcDirect => w.write_all(encode::u32(P2P_WEBRTC_DIRECT, &mut buf))?,
Protocol::P2pCircuit => w.write_all(encode::u32(P2P_CIRCUIT, &mut buf))?,
Protocol::Memory => w.write_all(encode::u32(MEMORY, &mut buf))?
Protocol::Memory(port) => {
w.write_all(encode::u32(MEMORY, &mut buf))?;
w.write_u64::<BigEndian>(*port)?
}
}
Ok(())
}
@ -334,7 +346,7 @@ impl<'a> Protocol<'a> {
P2pWebRtcDirect => P2pWebRtcDirect,
P2pWebRtcStar => P2pWebRtcStar,
P2pWebSocketStar => P2pWebSocketStar,
Memory => Memory,
Memory(a) => Memory(a),
Onion(addr, port) => Onion(Cow::Owned(addr.into_owned()), port),
P2p(a) => P2p(a),
P2pCircuit => P2pCircuit,
@ -365,7 +377,7 @@ impl<'a> fmt::Display for Protocol<'a> {
P2pWebRtcDirect => f.write_str("/p2p-webrtc-direct"),
P2pWebRtcStar => f.write_str("/p2p-webrtc-star"),
P2pWebSocketStar => f.write_str("/p2p-websocket-star"),
Memory => f.write_str("/memory"),
Memory(port) => write!(f, "/memory/{}", port),
Onion(addr, port) => {
let s = BASE32.encode(addr.as_ref());
write!(f, "/onion/{}:{}", s.to_lowercase(), port)

View File

@ -62,7 +62,7 @@ impl Arbitrary for Proto {
7 => Proto(P2pWebRtcDirect),
8 => Proto(P2pWebRtcStar),
9 => Proto(P2pWebSocketStar),
10 => Proto(Memory),
10 => Proto(Memory(g.gen())),
// TODO: impl Arbitrary for Multihash:
11 => Proto(P2p(multihash("QmcgpsyWgH8Y8ajJz1Cu72KnS5uo2Aa2LpzU7kinSupNKC"))),
12 => Proto(P2pCircuit),