diff --git a/multistream-select/Cargo.toml b/multistream-select/Cargo.toml index d10f3c8d..3f8471c6 100644 --- a/multistream-select/Cargo.toml +++ b/multistream-select/Cargo.toml @@ -8,6 +8,7 @@ bytes = "0.4" futures = { version = "0.1" } smallvec = "0.5" tokio-io = "0.1" +varint = { path = "../varint-rs" } [dev-dependencies] tokio-core = "0.1" diff --git a/multistream-select/src/lib.rs b/multistream-select/src/lib.rs index 0e668a3a..c9675e6c 100644 --- a/multistream-select/src/lib.rs +++ b/multistream-select/src/lib.rs @@ -122,6 +122,7 @@ extern crate bytes; extern crate futures; extern crate smallvec; extern crate tokio_io; +extern crate varint; mod dialer_select; mod error; diff --git a/multistream-select/src/protocol/dialer.rs b/multistream-select/src/protocol/dialer.rs index 2264e77b..917882b7 100644 --- a/multistream-select/src/protocol/dialer.rs +++ b/multistream-select/src/protocol/dialer.rs @@ -25,12 +25,13 @@ use futures::{Async, AsyncSink, Future, Poll, Sink, StartSend, Stream}; use length_delimited::LengthDelimitedFramedRead; use protocol::DialerToListenerMessage; use protocol::ListenerToDialerMessage; - use protocol::MULTISTREAM_PROTOCOL_WITH_LF; use protocol::MultistreamSelectError; +use std::io::{Cursor, Read, BufRead}; use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::codec::length_delimited::Builder as LengthDelimitedBuilder; use tokio_io::codec::length_delimited::FramedWrite as LengthDelimitedFramedWrite; +use varint; /// Wraps around a `AsyncRead+AsyncWrite`. Assumes that we're on the dialer's side. Produces and /// accepts messages. @@ -149,7 +150,32 @@ impl Stream for Dialer } else { // A varint number of protocols - unimplemented!() + let mut reader = Cursor::new(frame); + let num_protocols: usize = varint::decode(reader.by_ref())?; + + let mut iter = BufRead::split(reader, b'\r'); + if !iter.next().ok_or(MultistreamSelectError::UnknownMessage)??.is_empty() { + return Err(MultistreamSelectError::UnknownMessage); + } + + let mut out = Vec::with_capacity(num_protocols); + for proto in iter.by_ref().take(num_protocols) { + let mut proto = proto?; + let poped = proto.pop(); // Pop the `\n` + if poped != Some(b'\n') { + return Err(MultistreamSelectError::UnknownMessage); + } + out.push(Bytes::from(proto)); + } + + // Making sure that the number of protocols was correct. + if iter.next().is_some() || out.len() != num_protocols { + return Err(MultistreamSelectError::UnknownMessage); + } + + return Ok(Async::Ready(Some(ListenerToDialerMessage::ProtocolsListResponse { + list: out + }))); } } } diff --git a/multistream-select/src/protocol/error.rs b/multistream-select/src/protocol/error.rs index ceaebab7..431c0e3b 100644 --- a/multistream-select/src/protocol/error.rs +++ b/multistream-select/src/protocol/error.rs @@ -22,13 +22,14 @@ use std::error; use std::fmt; -use std::io::Error as IoError; +use std::io; +use varint; /// Error at the multistream-select layer of communication. #[derive(Debug)] pub enum MultistreamSelectError { /// I/O error. - IoError(IoError), + IoError(io::Error), /// The remote doesn't use the same multistream-select protocol as we do. FailedHandshake, @@ -38,15 +39,26 @@ pub enum MultistreamSelectError { /// Protocol names must always start with `/`, otherwise this error is returned. WrongProtocolName, + + /// Failure to parse variable-length integer. + // TODO: we don't include the actual error, because that would remove Send from the enum + VarintParseError(String), } -impl From for MultistreamSelectError { +impl From for MultistreamSelectError { #[inline] - fn from(err: IoError) -> MultistreamSelectError { + fn from(err: io::Error) -> MultistreamSelectError { MultistreamSelectError::IoError(err) } } +impl From for MultistreamSelectError { + #[inline] + fn from(err: varint::Error) -> MultistreamSelectError { + MultistreamSelectError::VarintParseError(err.to_string()) + } +} + impl error::Error for MultistreamSelectError { #[inline] fn description(&self) -> &str { @@ -63,6 +75,9 @@ impl error::Error for MultistreamSelectError { MultistreamSelectError::WrongProtocolName => { "protocol names must always start with `/`, otherwise this error is returned" }, + MultistreamSelectError::VarintParseError(_) => { + "failure to parse variable-length integer" + }, } } diff --git a/multistream-select/src/protocol/listener.rs b/multistream-select/src/protocol/listener.rs index 1c9274c2..d2bfd59d 100644 --- a/multistream-select/src/protocol/listener.rs +++ b/multistream-select/src/protocol/listener.rs @@ -25,12 +25,12 @@ use futures::{Async, AsyncSink, Future, Poll, Sink, StartSend, Stream}; use length_delimited::LengthDelimitedFramedRead; use protocol::DialerToListenerMessage; use protocol::ListenerToDialerMessage; - use protocol::MULTISTREAM_PROTOCOL_WITH_LF; use protocol::MultistreamSelectError; use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::codec::length_delimited::Builder as LengthDelimitedBuilder; use tokio_io::codec::length_delimited::FramedWrite as LengthDelimitedFramedWrite; +use varint; /// Wraps around a `AsyncRead+AsyncWrite`. Assumes that we're on the listener's side. Produces and /// accepts messages. @@ -112,8 +112,22 @@ impl Sink for Listener } } - ListenerToDialerMessage::ProtocolsListResponse { list: _list } => { - unimplemented!() + ListenerToDialerMessage::ProtocolsListResponse { list } => { + let mut out_msg = varint::encode(list.len()); + for elem in list.iter() { + out_msg.push(b'\r'); + out_msg.extend_from_slice(elem); + out_msg.push(b'\n'); + } + + match self.inner.start_send(BytesMut::from(out_msg)) { + Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready), + Ok(AsyncSink::NotReady(_)) => { + let m = ListenerToDialerMessage::ProtocolsListResponse { list }; + Ok(AsyncSink::NotReady(m)) + } + Err(err) => Err(err.into()), + } } } } diff --git a/multistream-select/src/tests.rs b/multistream-select/src/tests.rs index 4fecfd61..7871520d 100644 --- a/multistream-select/src/tests.rs +++ b/multistream-select/src/tests.rs @@ -147,7 +147,6 @@ fn no_protocol_found() { } #[test] -#[ignore] // TODO: not yet implemented in the listener fn select_proto_parallel() { let mut core = Core::new().unwrap();