From db6be0aa8b43fffdd7c619628cad5893b6f9bae2 Mon Sep 17 00:00:00 2001 From: Pierre Krieger Date: Sun, 5 Nov 2017 12:21:34 +0100 Subject: [PATCH] Implement multistream-select --- Cargo.toml | 1 + multistream-select/Cargo.toml | 10 + multistream-select/README.md | 14 ++ multistream-select/src/dialer_select.rs | 163 +++++++++++++++ multistream-select/src/error.rs | 30 +++ multistream-select/src/lib.rs | 211 ++++++++++++++++++++ multistream-select/src/listener_select.rs | 81 ++++++++ multistream-select/src/protocol/dialer.rs | 170 ++++++++++++++++ multistream-select/src/protocol/error.rs | 24 +++ multistream-select/src/protocol/listener.rs | 174 ++++++++++++++++ multistream-select/src/protocol/mod.rs | 47 +++++ 11 files changed, 925 insertions(+) create mode 100644 multistream-select/Cargo.toml create mode 100644 multistream-select/README.md create mode 100644 multistream-select/src/dialer_select.rs create mode 100644 multistream-select/src/error.rs create mode 100644 multistream-select/src/lib.rs create mode 100644 multistream-select/src/listener_select.rs create mode 100644 multistream-select/src/protocol/dialer.rs create mode 100644 multistream-select/src/protocol/error.rs create mode 100644 multistream-select/src/protocol/listener.rs create mode 100644 multistream-select/src/protocol/mod.rs diff --git a/Cargo.toml b/Cargo.toml index 4c5f6ca7..6424a19c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,6 @@ [workspace] members = [ + "multistream-select", "libp2p-host", "libp2p-transport", "libp2p-tcp-transport", diff --git a/multistream-select/Cargo.toml b/multistream-select/Cargo.toml new file mode 100644 index 00000000..b3893356 --- /dev/null +++ b/multistream-select/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "multistream-select" +version = "0.1.0" +authors = ["Parity Technologies "] + +[dependencies] +bytes = "0.4" +futures = { version = "0.1" } +tokio-core = "0.1" +tokio-io = "0.1" diff --git a/multistream-select/README.md b/multistream-select/README.md new file mode 100644 index 00000000..9d6b01cd --- /dev/null +++ b/multistream-select/README.md @@ -0,0 +1,14 @@ +# Multistream-select + +Multistream-select is the "main" protocol of libp2p. +Whenever a connection opens between two peers, it starts talking in `multistream-select`. + +The purpose of `multistream-select` is to choose which protocol we are going to use. As soon as +both sides agree on a given protocol, the socket immediately starts using it and multistream is no +longer relevant. + +However note that `multistream-select` is also sometimes used on top of another protocol such as +secio or multiplex. For example, two hosts can use `multistream-select` to decide to use secio, +then use `multistream-select` again (wrapped inside `secio`) to decide to use `multiplex`, then use +`multistream-select` one more time (wrapped inside `secio` and `multiplex`) to decide to use +the final actual protocol. diff --git a/multistream-select/src/dialer_select.rs b/multistream-select/src/dialer_select.rs new file mode 100644 index 00000000..223a96ca --- /dev/null +++ b/multistream-select/src/dialer_select.rs @@ -0,0 +1,163 @@ + +use ProtocolChoiceError; +use bytes::Bytes; +use futures::{Future, Sink, Stream}; +use futures::future::{result, loop_fn, Loop}; + +use protocol::Dialer; +use protocol::DialerToListenerMessage; +use protocol::ListenerToDialerMessage; +use tokio_io::{AsyncRead, AsyncWrite}; + +/// Helps selecting a protocol amongst the ones supported. +/// +/// This function expects a socket and a list of protocols. It uses the `multistream-select` +/// protocol to choose with the remote a protocol amongst the ones produced by the iterator. +/// +/// The iterator must produce a tuple of a protocol name advertised to the remote, a function that +/// checks whether a protocol name matches the protocol, and a protocol "identifier" of type `P` +/// (you decide what `P` is). The parameters of the match function are the name proposed by the +/// remote, and the protocol name that we passed (so that you don't have to clone the name). On +/// success, the function returns the identifier (of type `P`), plus the socket which now uses that +/// chosen protocol. +// TODO: remove the Box once -> impl Trait lands +#[inline] +pub fn dialer_select_proto<'a, R, I, M, P>( + inner: R, + protocols: I, +) -> Box + 'a> +where + R: AsyncRead + AsyncWrite + 'a, + I: Iterator + 'a, + M: FnMut(&Bytes, &Bytes) -> bool + 'a, + P: 'a, +{ + // We choose between the "serial" and "parallel" strategies based on the number of protocols. + if protocols.size_hint().1.map(|n| n <= 3).unwrap_or(false) { + dialer_select_proto_serial(inner, protocols.map(|(n, _, id)| (n, id))) + } else { + dialer_select_proto_parallel(inner, protocols) + } +} + +/// Helps selecting a protocol amongst the ones supported. +/// +/// Same as `dialer_select_proto`. Tries protocols one by one. The iterator doesn't need to produce +/// match functions, because it's not needed. +// TODO: remove the Box once -> impl Trait lands +pub fn dialer_select_proto_serial<'a, R, I, P>( + inner: R, + mut protocols: I, +) -> Box + 'a> +where + R: AsyncRead + AsyncWrite + 'a, + I: Iterator + 'a, + P: 'a, +{ + let future = Dialer::new(inner) + .from_err() + .and_then(move |dialer| { + // Similar to a `loop` keyword. + loop_fn(dialer, move |dialer| { + result(protocols.next().ok_or(ProtocolChoiceError::NoProtocolFound)) + // If the `protocols` iterator produced an element, send it to the dialer + .and_then(|(proto_name, proto_value)| { + dialer.send(DialerToListenerMessage::ProtocolRequest { name: proto_name.clone() }) + .map(|d| (d, proto_name, proto_value)) + .from_err() + }) + // Once sent, read one element from `dialer`. + .and_then(|(dialer, proto_name, proto_value)| { + dialer + .into_future() + .map(|(msg, rest)| (msg, rest, proto_name, proto_value)) + .map_err(|(e, _)| e.into()) + }) + // Once read, analyze the response. + .and_then(|(message, rest, proto_name, proto_value)| { + let message = message.ok_or(ProtocolChoiceError::UnexpectedMessage)?; + + match message { + ListenerToDialerMessage::ProtocolAck { ref name } + if name == &proto_name => + { + // Satisfactory response, break the loop. + Ok(Loop::Break((proto_value, rest.into_inner()))) + }, + ListenerToDialerMessage::NotAvailable => { + Ok(Loop::Continue(rest)) + }, + _ => Err(ProtocolChoiceError::UnexpectedMessage) + } + }) + }) + }); + + // The "Rust doesn't have impl Trait yet" tax. + Box::new(future) +} + +/// Helps selecting a protocol amongst the ones supported. +/// +/// Same as `dialer_select_proto`. Queries the list of supported protocols from the remote, then +/// chooses the most appropriate one. +// TODO: remove the Box once -> impl Trait lands +pub fn dialer_select_proto_parallel<'a, R, I, M, P>( + inner: R, + protocols: I, +) -> Box + 'a> +where + R: AsyncRead + AsyncWrite + 'a, + I: Iterator + 'a, + M: FnMut(&Bytes, &Bytes) -> bool + 'a, + P: 'a, +{ + let future = Dialer::new(inner) + .from_err() + .and_then( + move |dialer| dialer.send(DialerToListenerMessage::ProtocolsListRequest).from_err(), + ) + .and_then(move |dialer| dialer.into_future().map_err(|(e, _)| e.into())) + .and_then(move |(msg, dialer)| { + let list = match msg { + Some(ListenerToDialerMessage::ProtocolsListResponse { list }) => list, + _ => return Err(ProtocolChoiceError::UnexpectedMessage), + }; + + let mut found = None; + for (local_name, mut match_fn, ident) in protocols { + for remote_name in &list { + if match_fn(remote_name, &local_name) { + found = Some((remote_name.clone(), ident)); + break; + } + } + + if found.is_some() { + break; + } + } + + let (proto_name, proto_val) = found.ok_or(ProtocolChoiceError::NoProtocolFound)?; + Ok((proto_name, proto_val, dialer)) + }) + .and_then(|(proto_name, proto_val, dialer)| { + dialer.send(DialerToListenerMessage::ProtocolRequest { name: proto_name.clone() }) + .from_err() + .map(|dialer| (proto_name, proto_val, dialer)) + }) + .and_then(|(proto_name, proto_val, dialer)| { + dialer.into_future() + .map(|(msg, rest)| (proto_name, proto_val, msg, rest)) + .map_err(|(err, _)| err.into()) + }) + .and_then(|(proto_name, proto_val, msg, dialer)| match msg { + Some(ListenerToDialerMessage::ProtocolAck { ref name }) if name == &proto_name => { + Ok((proto_val, dialer.into_inner())) + } + _ => Err(ProtocolChoiceError::UnexpectedMessage), + }); + + // The "Rust doesn't have impl Trait yet" tax. + Box::new(future) +} diff --git a/multistream-select/src/error.rs b/multistream-select/src/error.rs new file mode 100644 index 00000000..2a8de534 --- /dev/null +++ b/multistream-select/src/error.rs @@ -0,0 +1,30 @@ + +use protocol::MultistreamSelectError; +use std::io::Error as IoError; + +/// Error that can happen when negociating a protocol with the remote. +#[derive(Debug)] +pub enum ProtocolChoiceError { + /// Error in the protocol. + MultistreamSelectError(MultistreamSelectError), + + /// Received a message from the remote that makes no sense in the current context. + UnexpectedMessage, + + /// We don't support any protocol in common with the remote. + NoProtocolFound, +} + +impl From for ProtocolChoiceError { + #[inline] + fn from(err: MultistreamSelectError) -> ProtocolChoiceError { + ProtocolChoiceError::MultistreamSelectError(err) + } +} + +impl From for ProtocolChoiceError { + #[inline] + fn from(err: IoError) -> ProtocolChoiceError { + MultistreamSelectError::from(err).into() + } +} diff --git a/multistream-select/src/lib.rs b/multistream-select/src/lib.rs new file mode 100644 index 00000000..9a37c3d7 --- /dev/null +++ b/multistream-select/src/lib.rs @@ -0,0 +1,211 @@ +extern crate bytes; +extern crate futures; +extern crate tokio_core; +extern crate tokio_io; + +mod dialer_select; +mod error; +mod listener_select; + +pub mod protocol; + +pub use self::dialer_select::dialer_select_proto; +pub use self::error::ProtocolChoiceError; +pub use self::listener_select::listener_select_proto; + +#[cfg(test)] +mod tests { + use {listener_select_proto, dialer_select_proto}; + use ProtocolChoiceError; + use bytes::Bytes; + use dialer_select::{dialer_select_proto_parallel, dialer_select_proto_serial}; + use futures::{Sink, Stream}; + use futures::Future; + use protocol::{Dialer, Listener, DialerToListenerMessage, ListenerToDialerMessage}; + use tokio_core::net::TcpListener; + use tokio_core::net::TcpStream; + use tokio_core::reactor::Core; + + #[test] + fn negociate_with_self_succeeds() { + let mut core = Core::new().unwrap(); + + let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap(); + let listener_addr = listener.local_addr().unwrap(); + + let server = listener.incoming() + .into_future() + .map_err(|(e, _)| e.into()) + .and_then(move |(connec, _)| Listener::new(connec.unwrap().0)) + .and_then(|l| l.into_future().map_err(|(e, _)| e)) + .and_then(|(msg, rest)| { + let proto = match msg { + Some(DialerToListenerMessage::ProtocolRequest { name }) => name, + _ => panic!(), + }; + rest.send(ListenerToDialerMessage::ProtocolAck { name: proto }) + }); + + let client = TcpStream::connect(&listener_addr, &core.handle()) + .from_err() + .and_then(move |stream| Dialer::new(stream)) + .and_then(move |dialer| { + let p = Bytes::from("/hello/1.0.0"); + dialer.send(DialerToListenerMessage::ProtocolRequest { name: p }) + }) + .and_then(move |dialer| dialer.into_future().map_err(|(e, _)| e)) + .and_then(move |(msg, _)| { + let proto = match msg { + Some(ListenerToDialerMessage::ProtocolAck { name }) => name, + _ => panic!(), + }; + assert_eq!(proto, "/hello/1.0.0"); + Ok(()) + }); + + core.run(server.join(client)).unwrap(); + } + + #[test] + fn select_proto_basic() { + let mut core = Core::new().unwrap(); + + let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap(); + let listener_addr = listener.local_addr().unwrap(); + + let server = listener.incoming() + .into_future() + .map(|s| s.0.unwrap().0) + .map_err(|(e, _)| e.into()) + .and_then(move |connec| { + let protos = vec![ + (Bytes::from("/proto1"), ::eq, 0), + (Bytes::from("/proto2"), ::eq, 1), + ] + .into_iter(); + listener_select_proto(connec, protos).map(|r| r.0) + }); + + let client = TcpStream::connect(&listener_addr, &core.handle()) + .from_err() + .and_then(move |connec| { + let protos = vec![ + (Bytes::from("/proto3"), ::eq, 2), + (Bytes::from("/proto2"), ::eq, 3), + ] + .into_iter(); + dialer_select_proto(connec, protos).map(|r| r.0) + }); + + let (dialer_chosen, listener_chosen) = core.run(client.join(server)).unwrap(); + assert_eq!(dialer_chosen, 3); + assert_eq!(listener_chosen, 1); + } + + #[test] + fn no_protocol_found() { + let mut core = Core::new().unwrap(); + + let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap(); + let listener_addr = listener.local_addr().unwrap(); + + let server = listener.incoming() + .into_future() + .map(|s| s.0.unwrap().0) + .map_err(|(e, _)| e.into()) + .and_then(move |connec| { + let protos = vec![ + (Bytes::from("/proto1"), ::eq, 1), + (Bytes::from("/proto2"), ::eq, 2), + ] + .into_iter(); + listener_select_proto(connec, protos).map(|r| r.0) + }); + + let client = TcpStream::connect(&listener_addr, &core.handle()) + .from_err() + .and_then(move |connec| { + let protos = vec![ + (Bytes::from("/proto3"), ::eq, 3), + (Bytes::from("/proto4"), ::eq, 4), + ] + .into_iter(); + dialer_select_proto(connec, protos).map(|r| r.0) + }); + + match core.run(client.join(server)) { + Err(ProtocolChoiceError::NoProtocolFound) => (), + _ => panic!(), + } + } + + #[test] + #[ignore] // TODO: not yet implemented in the listener + fn select_proto_parallel() { + let mut core = Core::new().unwrap(); + + let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap(); + let listener_addr = listener.local_addr().unwrap(); + + let server = listener.incoming() + .into_future() + .map(|s| s.0.unwrap().0) + .map_err(|(e, _)| e.into()) + .and_then(move |connec| { + let protos = vec![ + (Bytes::from("/proto1"), ::eq, 0), + (Bytes::from("/proto2"), ::eq, 1), + ] + .into_iter(); + listener_select_proto(connec, protos).map(|r| r.0) + }); + + let client = TcpStream::connect(&listener_addr, &core.handle()) + .from_err() + .and_then(move |connec| { + let protos = vec![ + (Bytes::from("/proto3"), ::eq, 2), + (Bytes::from("/proto2"), ::eq, 3), + ] + .into_iter(); + dialer_select_proto_parallel(connec, protos).map(|r| r.0) + }); + + let (dialer_chosen, listener_chosen) = core.run(client.join(server)).unwrap(); + assert_eq!(dialer_chosen, 3); + assert_eq!(listener_chosen, 1); + } + + #[test] + fn select_proto_serial() { + let mut core = Core::new().unwrap(); + + let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap(); + let listener_addr = listener.local_addr().unwrap(); + + let server = listener.incoming() + .into_future() + .map(|s| s.0.unwrap().0) + .map_err(|(e, _)| e.into()) + .and_then(move |connec| { + let protos = vec![ + (Bytes::from("/proto1"), ::eq, 0), + (Bytes::from("/proto2"), ::eq, 1), + ] + .into_iter(); + listener_select_proto(connec, protos).map(|r| r.0) + }); + + let client = TcpStream::connect(&listener_addr, &core.handle()) + .from_err() + .and_then(move |connec| { + let protos = vec![(Bytes::from("/proto3"), 2), (Bytes::from("/proto2"), 3)] + .into_iter(); + dialer_select_proto_serial(connec, protos).map(|r| r.0) + }); + + let (dialer_chosen, listener_chosen) = core.run(client.join(server)).unwrap(); + assert_eq!(dialer_chosen, 3); + assert_eq!(listener_chosen, 1); + } +} diff --git a/multistream-select/src/listener_select.rs b/multistream-select/src/listener_select.rs new file mode 100644 index 00000000..12da6937 --- /dev/null +++ b/multistream-select/src/listener_select.rs @@ -0,0 +1,81 @@ + +use ProtocolChoiceError; +use bytes::Bytes; +use futures::{Future, Sink, Stream}; +use futures::future::{err, loop_fn, Loop}; + +use protocol::DialerToListenerMessage; +use protocol::Listener; +use protocol::ListenerToDialerMessage; +use tokio_io::{AsyncRead, AsyncWrite}; + +/// Helps selecting a protocol amongst the ones supported. +/// +/// This function expects a socket and an iterator of the list of supported protocols. The iterator +/// must be clonable (ie. iterable multiple times), because the list may need to be accessed +/// multiple times. +/// +/// The iterator must produce tuples of the name of the protocol that is advertised to the remote, +/// a function that will check whether a remote protocol matches ours, and an identifier for the +/// protocol of type `P` (you decide what `P` is). The parameters of the function are the name +/// proposed by the remote, and the protocol name that we passed (so that you don't have to clone +/// the name). +/// +/// On success, returns the socket and the identifier of the chosen protocol (of type `P`). The +/// socket now uses this protocol. +// TODO: remove the Box once -> impl Trait lands +pub fn listener_select_proto<'a, R, I, M, P>( + inner: R, + protocols: I, +) -> Box + 'a> +where + R: AsyncRead + AsyncWrite + 'a, + I: Iterator + Clone + 'a, + M: FnMut(&Bytes, &Bytes) -> bool + 'a, + P: 'a, +{ + let future = Listener::new(inner).from_err().and_then(move |listener| { + + loop_fn(listener, move |listener| { + let protocols = protocols.clone(); + + listener.into_future() + .map_err(|(e, _)| e.into()) + .and_then(move |(message, listener)| match message { + Some(DialerToListenerMessage::ProtocolsListRequest) => { + let msg = ListenerToDialerMessage::ProtocolsListResponse { + list: protocols.map(|(p, _, _)| p).collect(), + }; + let fut = listener.send(msg).from_err().map(move |listener| (None, listener)); + Box::new(fut) as Box> + } + Some(DialerToListenerMessage::ProtocolRequest { name }) => { + let mut outcome = None; + let mut send_back = ListenerToDialerMessage::NotAvailable; + for (supported, mut matches, value) in protocols { + if matches(&name, &supported) { + send_back = ListenerToDialerMessage::ProtocolAck { name: name.clone() }; + outcome = Some(value); + break; + } + } + + let fut = listener.send(send_back) + .from_err() + .map(move |listener| (outcome, listener)); + Box::new(fut) as Box> + } + None => { + Box::new(err(ProtocolChoiceError::NoProtocolFound)) as Box<_> + } + }) + .map(|(outcome, listener): (_, Listener)| match outcome { + Some(outcome) => Loop::Break((outcome, listener.into_inner())), + None => Loop::Continue(listener), + }) + }) + }); + + // The "Rust doesn't have impl Trait yet" tax. + Box::new(future) +} diff --git a/multistream-select/src/protocol/dialer.rs b/multistream-select/src/protocol/dialer.rs new file mode 100644 index 00000000..8c9c5603 --- /dev/null +++ b/multistream-select/src/protocol/dialer.rs @@ -0,0 +1,170 @@ +use bytes::BytesMut; +use futures::{Async, AsyncSink, Future, Poll, Sink, StartSend, Stream}; +use protocol::DialerToListenerMessage; +use protocol::ListenerToDialerMessage; + +use protocol::MULTISTREAM_PROTOCOL_WITH_LF; +use protocol::MultistreamSelectError; +use tokio_io::{AsyncRead, AsyncWrite}; +use tokio_io::codec::length_delimited::Builder as LengthDelimitedBuilder; +use tokio_io::codec::length_delimited::Framed as LengthDelimitedFramed; + +/// Wraps around a `AsyncRead+AsyncWrite`. Assumes that we're on the dialer's side. Produces and +/// accepts messages. +pub struct Dialer { + inner: LengthDelimitedFramed, + handshake_finished: bool, +} + +impl Dialer +where + R: AsyncRead + AsyncWrite, +{ + /// Takes ownership of a socket and starts the handshake. If the handshake succeeds, the + /// future returns a `Dialer`. + pub fn new<'a>(inner: R) -> Box, Error = MultistreamSelectError> + 'a> + where + R: 'a, + { + // TODO: use Jack's lib instead + let inner = LengthDelimitedBuilder::new().length_field_length(1).new_framed(inner); + + let future = + inner.send(BytesMut::from(MULTISTREAM_PROTOCOL_WITH_LF)).from_err().map(|inner| { + Dialer { + inner: inner, + handshake_finished: false, + } + }); + Box::new(future) + } + + /// Grants back the socket. Typically used after a `ProtocolAck` has been received. + #[inline] + pub fn into_inner(self) -> R { + self.inner.into_inner() + } +} + +impl Sink for Dialer +where + R: AsyncRead + AsyncWrite, +{ + type SinkItem = DialerToListenerMessage; + type SinkError = MultistreamSelectError; + + fn start_send(&mut self, item: Self::SinkItem) -> StartSend { + match item { + DialerToListenerMessage::ProtocolRequest { name } => { + if !name.starts_with(b"/") { + return Err(MultistreamSelectError::WrongProtocolName); + } + let mut protocol = BytesMut::from(name); + protocol.extend_from_slice(&[b'\n']); + match self.inner.start_send(protocol) { + Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready), + Ok(AsyncSink::NotReady(mut protocol)) => { + let protocol_len = protocol.len(); + protocol.truncate(protocol_len - 1); + let protocol = protocol.freeze(); + Ok(AsyncSink::NotReady( + DialerToListenerMessage::ProtocolRequest { name: protocol }, + )) + } + Err(err) => Err(err.into()), + } + } + + DialerToListenerMessage::ProtocolsListRequest => { + match self.inner.start_send(BytesMut::from(&b"ls\n"[..])) { + Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready), + Ok(AsyncSink::NotReady(_)) => { + Ok(AsyncSink::NotReady(DialerToListenerMessage::ProtocolsListRequest)) + } + Err(err) => Err(err.into()), + } + } + } + } + + #[inline] + fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { + Ok(self.inner.poll_complete()?) + } +} + +impl Stream for Dialer +where + R: AsyncRead + AsyncWrite, +{ + type Item = ListenerToDialerMessage; + type Error = MultistreamSelectError; + + fn poll(&mut self) -> Poll, Self::Error> { + loop { + let frame = match self.inner.poll() { + Ok(Async::Ready(Some(frame))) => frame, + Ok(Async::Ready(None)) => return Ok(Async::Ready(None)), + Ok(Async::NotReady) => return Ok(Async::NotReady), + Err(err) => return Err(err.into()), + }; + + if !self.handshake_finished { + if frame == MULTISTREAM_PROTOCOL_WITH_LF { + self.handshake_finished = true; + continue; + } else { + return Err(MultistreamSelectError::FailedHandshake); + } + } + + if frame.get(0) == Some(&b'/') && frame.last() == Some(&b'\n') { + let frame = frame.freeze(); + let protocol = frame.slice_to(frame.len() - 1); + return Ok( + Async::Ready(Some(ListenerToDialerMessage::ProtocolAck { name: protocol })), + ); + + } else if frame == &b"na\n"[..] { + return Ok(Async::Ready(Some(ListenerToDialerMessage::NotAvailable))); + + } else { + // A varint number of protocols + unimplemented!() + } + } + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + use futures::{Sink, Stream}; + use futures::Future; + use protocol::{Dialer, DialerToListenerMessage, MultistreamSelectError}; + use tokio_core::net::{TcpListener, TcpStream}; + use tokio_core::reactor::Core; + + #[test] + fn wrong_proto_name() { + let mut core = Core::new().unwrap(); + + let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap(); + let listener_addr = listener.local_addr().unwrap(); + + let server = listener.incoming().into_future().map(|_| ()).map_err(|(e, _)| e.into()); + + let client = TcpStream::connect(&listener_addr, &core.handle()) + .from_err() + .and_then(move |stream| Dialer::new(stream)) + .and_then(move |dialer| { + let p = Bytes::from("invalid_name"); + dialer.send(DialerToListenerMessage::ProtocolRequest { name: p }) + }); + + match core.run(server.join(client)) { + Err(MultistreamSelectError::WrongProtocolName) => (), + _ => panic!(), + } + } +} diff --git a/multistream-select/src/protocol/error.rs b/multistream-select/src/protocol/error.rs new file mode 100644 index 00000000..811511a3 --- /dev/null +++ b/multistream-select/src/protocol/error.rs @@ -0,0 +1,24 @@ +use std::io::Error as IoError; + +/// Error at the multistream-select layer of communication. +#[derive(Debug)] +pub enum MultistreamSelectError { + /// I/O error. + IoError(IoError), + + /// The remote doesn't use the same multistream-select protocol as we do. + FailedHandshake, + + /// Received an unknown message from the remote. + UnknownMessage, + + /// Protocol names must always start with `/`, otherwise this error is returned. + WrongProtocolName, +} + +impl From for MultistreamSelectError { + #[inline] + fn from(err: IoError) -> MultistreamSelectError { + MultistreamSelectError::IoError(err) + } +} diff --git a/multistream-select/src/protocol/listener.rs b/multistream-select/src/protocol/listener.rs new file mode 100644 index 00000000..a801fb35 --- /dev/null +++ b/multistream-select/src/protocol/listener.rs @@ -0,0 +1,174 @@ +use bytes::BytesMut; +use futures::{Async, AsyncSink, Future, Poll, Sink, StartSend, Stream}; +use protocol::DialerToListenerMessage; +use protocol::ListenerToDialerMessage; + +use protocol::MULTISTREAM_PROTOCOL_WITH_LF; +use protocol::MultistreamSelectError; +use tokio_io::{AsyncRead, AsyncWrite}; +use tokio_io::codec::length_delimited::Builder as LengthDelimitedBuilder; +use tokio_io::codec::length_delimited::Framed as LengthDelimitedFramed; + +/// Wraps around a `AsyncRead+AsyncWrite`. Assumes that we're on the listener's side. Produces and +/// accepts messages. +pub struct Listener { + inner: LengthDelimitedFramed, +} + +impl Listener +where + R: AsyncRead + AsyncWrite, +{ + /// Takes ownership of a socket and starts the handshake. If the handshake succeeds, the + /// future returns a `Listener`. + pub fn new<'a>(inner: R) -> Box, Error = MultistreamSelectError> + 'a> + where + R: 'a, + { + // TODO: use Jack's lib instead + let inner = LengthDelimitedBuilder::new().length_field_length(1).new_framed(inner); + + let future = inner.into_future() + .map_err(|(e, _)| e.into()) + .and_then(|(msg, rest)| { + if msg.as_ref().map(|b| &b[..]) != Some(MULTISTREAM_PROTOCOL_WITH_LF) { + return Err(MultistreamSelectError::FailedHandshake); + } + Ok(rest) + }) + .and_then(|socket| { + socket.send(BytesMut::from(MULTISTREAM_PROTOCOL_WITH_LF)).from_err() + }) + .map(|inner| Listener { inner: inner }); + + Box::new(future) + } + + /// Grants back the socket. Typically used after a `ProtocolRequest` has been received and a + /// `ProtocolAck` has been sent back. + #[inline] + pub fn into_inner(self) -> R { + self.inner.into_inner() + } +} + +impl Sink for Listener +where + R: AsyncRead + AsyncWrite, +{ + type SinkItem = ListenerToDialerMessage; + type SinkError = MultistreamSelectError; + + #[inline] + fn start_send(&mut self, item: Self::SinkItem) -> StartSend { + match item { + ListenerToDialerMessage::ProtocolAck { name } => { + if !name.starts_with(b"/") { + return Err(MultistreamSelectError::WrongProtocolName); + } + let mut protocol = BytesMut::from(name); + protocol.extend_from_slice(&[b'\n']); + match self.inner.start_send(protocol) { + Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready), + Ok(AsyncSink::NotReady(mut protocol)) => { + let protocol_len = protocol.len(); + protocol.truncate(protocol_len - 1); + let protocol = protocol.freeze(); + Ok( + AsyncSink::NotReady(ListenerToDialerMessage::ProtocolAck { name: protocol }), + ) + } + Err(err) => Err(err.into()), + } + } + + ListenerToDialerMessage::NotAvailable => { + match self.inner.start_send(BytesMut::from(&b"na\n"[..])) { + Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready), + Ok(AsyncSink::NotReady(_)) => { + Ok(AsyncSink::NotReady(ListenerToDialerMessage::NotAvailable)) + } + Err(err) => Err(err.into()), + } + } + + ListenerToDialerMessage::ProtocolsListResponse { list } => { + unimplemented!() + } + } + } + + #[inline] + fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { + Ok(self.inner.poll_complete()?) + } +} + +impl Stream for Listener +where + R: AsyncRead + AsyncWrite, +{ + type Item = DialerToListenerMessage; + type Error = MultistreamSelectError; + + fn poll(&mut self) -> Poll, Self::Error> { + loop { + let frame = match self.inner.poll() { + Ok(Async::Ready(Some(frame))) => frame, + Ok(Async::Ready(None)) => return Ok(Async::Ready(None)), + Ok(Async::NotReady) => return Ok(Async::NotReady), + Err(err) => return Err(err.into()), + }; + + if frame.get(0) == Some(&b'/') && frame.last() == Some(&b'\n') { + let frame = frame.freeze(); + let protocol = frame.slice_to(frame.len() - 1); + return Ok(Async::Ready( + Some(DialerToListenerMessage::ProtocolRequest { name: protocol }), + )); + + } else if frame == &b"ls\n"[..] { + return Ok(Async::Ready(Some(DialerToListenerMessage::ProtocolsListRequest))); + + } else { + return Err(MultistreamSelectError::UnknownMessage); + } + } + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + use futures::{Sink, Stream}; + use futures::Future; + use protocol::{Dialer, Listener, ListenerToDialerMessage, MultistreamSelectError}; + use tokio_core::net::{TcpListener, TcpStream}; + use tokio_core::reactor::Core; + + #[test] + fn wrong_proto_name() { + let mut core = Core::new().unwrap(); + + let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap(); + let listener_addr = listener.local_addr().unwrap(); + + let server = listener.incoming() + .into_future() + .map_err(|(e, _)| e.into()) + .and_then(move |(connec, _)| Listener::new(connec.unwrap().0)) + .and_then(|listener| { + let proto_name = Bytes::from("invalid-proto"); + listener.send(ListenerToDialerMessage::ProtocolAck { name: proto_name }) + }); + + let client = TcpStream::connect(&listener_addr, &core.handle()) + .from_err() + .and_then(move |stream| Dialer::new(stream)); + + match core.run(server.join(client)) { + Err(MultistreamSelectError::WrongProtocolName) => (), + _ => panic!(), + } + } +} diff --git a/multistream-select/src/protocol/mod.rs b/multistream-select/src/protocol/mod.rs new file mode 100644 index 00000000..427dc13a --- /dev/null +++ b/multistream-select/src/protocol/mod.rs @@ -0,0 +1,47 @@ +//! Contains lower-level structs to handle the multistream protocol. + +use bytes::Bytes; + +mod dialer; +mod error; +mod listener; + +const MULTISTREAM_PROTOCOL_WITH_LF: &'static [u8] = b"/multistream/1.0.0\n"; + +pub use self::dialer::Dialer; +pub use self::error::MultistreamSelectError; +pub use self::listener::Listener; + +/// Message sent from the dialer to the listener. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum DialerToListenerMessage { + /// The dialer wants us to use a protocol. + /// + /// If this is accepted (by receiving back a `ProtocolAck`), then we immediately start + /// communicating in the new protocol. + ProtocolRequest { + /// Name of the protocol. + name: Bytes, + }, + + /// The dialer requested the list of protocols that the listener supports. + ProtocolsListRequest, +} + +/// Message sent from the listener to the dialer. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ListenerToDialerMessage { + /// The protocol requested by the dialer is accepted. The socket immediately starts using the + /// new protocol. + ProtocolAck { name: Bytes }, + + /// The protocol requested by the dialer is not supported or available. + NotAvailable, + + /// Response to the request for the list of protocols. + ProtocolsListResponse { + /// The list of protocols. + // TODO: use some sort of iterator + list: Vec, + }, +}