diff --git a/.circleci/config.yml b/.circleci/config.yml index db8afb59..22ffd304 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -37,6 +37,9 @@ jobs: - run: name: Run tests, inside a docker image, with all features command: docker run --rm -v "/cache/cargo/registry:/usr/local/cargo/registry" -v "/cache/target:/app/target" -it rust-libp2p cargo test --all --all-features + - run: + name: Try the async-await feature + command: docker run --rm -v "/cache/cargo/registry:/usr/local/cargo/registry" -v "/cache/target:/app/target" -it rust-libp2p cargo +nightly test --package libp2p-core --all-features - save_cache: key: test-cache paths: @@ -48,7 +51,7 @@ jobs: steps: - checkout - restore_cache: - keys: + keys: - test-wasm-cache-{{ epoch }} - test-wasm-cache - run: diff --git a/Cargo.toml b/Cargo.toml index cbef55bd..c7e39870 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -78,3 +78,7 @@ members = [ "transports/websocket", "transports/wasm-ext" ] + +# TODO: remove after https://github.com/matthunz/futures-codec/issues/22 +[patch.crates-io] +futures_codec = { git = "https://github.com/matthunz/futures-codec" } diff --git a/core/Cargo.toml b/core/Cargo.toml index d8f44e43..f0628340 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -16,12 +16,13 @@ bytes = "0.4" ed25519-dalek = "1.0.0-pre.1" failure = "0.1" fnv = "1.0" +futures-timer = "0.3" lazy_static = "1.2" log = "0.4" multiaddr = { package = "parity-multiaddr", version = "0.5.0", path = "../misc/multiaddr" } multihash = { package = "parity-multihash", version = "0.1.0", path = "../misc/multihash" } multistream-select = { version = "0.5.0", path = "../misc/multistream-select" } -futures = "0.1" +futures-preview = { version = "0.3.0-alpha.17", features = ["compat", "io-compat"] } parking_lot = "0.8" protobuf = "2.3" quick-error = "1.2" @@ -30,8 +31,6 @@ rw-stream-sink = { version = "0.1.1", path = "../misc/rw-stream-sink" } libsecp256k1 = { version = "0.2.2", optional = true } sha2 = "0.8.0" smallvec = "0.6" -tokio-executor = "0.1.4" -tokio-io = "0.1" wasm-timer = "0.1" unsigned-varint = "0.2" void = "1" @@ -42,6 +41,7 @@ ring = { version = "0.14", features = ["use_heap"], default-features = false } untrusted = { version = "0.6" } [dev-dependencies] +async-std = "0.99" libp2p-swarm = { version = "0.2.0", path = "../swarm" } libp2p-tcp = { version = "0.12.0", path = "../transports/tcp" } libp2p-mplex = { version = "0.12.0", path = "../muxers/mplex" } @@ -56,4 +56,4 @@ tokio-mock-task = "0.1" [features] default = ["secp256k1"] secp256k1 = ["libsecp256k1"] - +async-await = [] diff --git a/core/src/either.rs b/core/src/either.rs index d17f8bb7..b81691a3 100644 --- a/core/src/either.rs +++ b/core/src/either.rs @@ -19,9 +19,8 @@ // DEALINGS IN THE SOFTWARE. 
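The hunks below repeatedly translate futures-0.1 poll results into the standard-library `Poll` type. As a reference for reading those changes, here is a minimal sketch of the mapping, assuming the 0.1 crate is available under a hypothetical `futures01` alias (the patch itself relies on the `compat`/`io-compat` features of `futures-preview` for the same job):

```rust
use std::task::Poll;

// Hypothetical helper, for illustration only: how an old `futures 0.1` poll
// result maps onto the `std::task::Poll` used throughout this patch.
fn poll_01_to_03<T, E>(old: Result<futures01::Async<T>, E>) -> Poll<Result<T, E>> {
    match old {
        Ok(futures01::Async::Ready(v)) => Poll::Ready(Ok(v)),
        Ok(futures01::Async::NotReady) => Poll::Pending,
        Err(e) => Poll::Ready(Err(e)),
    }
}
```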
use crate::{muxing::StreamMuxer, ProtocolName, transport::ListenerEvent}; -use futures::prelude::*; -use std::{fmt, io::{Error as IoError, Read, Write}}; -use tokio_io::{AsyncRead, AsyncWrite}; +use futures::{prelude::*, io::Initializer}; +use std::{fmt, io::{Error as IoError, Read, Write}, pin::Pin, task::Context, task::Poll}; #[derive(Debug, Copy, Clone)] pub enum EitherError { @@ -65,24 +64,25 @@ pub enum EitherOutput { impl AsyncRead for EitherOutput where - A: AsyncRead, - B: AsyncRead, + A: AsyncRead + Unpin, + B: AsyncRead + Unpin, { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + unsafe fn initializer(&self) -> Initializer { match self { - EitherOutput::First(a) => a.prepare_uninitialized_buffer(buf), - EitherOutput::Second(b) => b.prepare_uninitialized_buffer(buf), + EitherOutput::First(a) => a.initializer(), + EitherOutput::Second(b) => b.initializer(), } } - fn read_buf(&mut self, buf: &mut Bu) -> Poll { - match self { - EitherOutput::First(a) => a.read_buf(buf), - EitherOutput::Second(b) => b.read_buf(buf), + fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { + match &mut *self { + EitherOutput::First(a) => AsyncRead::poll_read(Pin::new(a), cx, buf), + EitherOutput::Second(b) => AsyncRead::poll_read(Pin::new(b), cx, buf), } } } +// TODO: remove? impl Read for EitherOutput where A: Read, @@ -98,17 +98,32 @@ where impl AsyncWrite for EitherOutput where - A: AsyncWrite, - B: AsyncWrite, + A: AsyncWrite + Unpin, + B: AsyncWrite + Unpin, { - fn shutdown(&mut self) -> Poll<(), IoError> { - match self { - EitherOutput::First(a) => a.shutdown(), - EitherOutput::Second(b) => b.shutdown(), + fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + match &mut *self { + EitherOutput::First(a) => AsyncWrite::poll_write(Pin::new(a), cx, buf), + EitherOutput::Second(b) => AsyncWrite::poll_write(Pin::new(b), cx, buf), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match &mut *self { + EitherOutput::First(a) => AsyncWrite::poll_flush(Pin::new(a), cx), + EitherOutput::Second(b) => AsyncWrite::poll_flush(Pin::new(b), cx), + } + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match &mut *self { + EitherOutput::First(a) => AsyncWrite::poll_close(Pin::new(a), cx), + EitherOutput::Second(b) => AsyncWrite::poll_close(Pin::new(b), cx), } } } +// TODO: remove? 
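The `Unpin` bounds added above are what make this delegation compile: because the inner types are `Unpin`, a fresh `Pin` can be created around a plain `&mut` reference on every call. A minimal sketch of the same pattern on a hypothetical single-field wrapper (not part of this patch):

```rust
use futures::io::AsyncRead;
use std::{io, pin::Pin, task::{Context, Poll}};

// Illustrative only: a hypothetical single-field wrapper.
struct Wrapper<T>(T);

impl<T: AsyncRead + Unpin> AsyncRead for Wrapper<T> {
    fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8])
        -> Poll<Result<usize, io::Error>>
    {
        // `Wrapper<T>` is `Unpin` whenever `T` is, so the pin can be dereferenced
        // and a fresh `Pin` created around the inner value on every call.
        let this = &mut *self;
        AsyncRead::poll_read(Pin::new(&mut this.0), cx, buf)
    }
}
```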
impl Write for EitherOutput where A: Write, @@ -131,46 +146,53 @@ where impl Stream for EitherOutput where - A: Stream, - B: Stream, + A: TryStream + Unpin, + B: TryStream + Unpin, { - type Item = I; - type Error = EitherError; + type Item = Result>; - fn poll(&mut self) -> Poll, Self::Error> { - match self { - EitherOutput::First(a) => a.poll().map_err(EitherError::A), - EitherOutput::Second(b) => b.poll().map_err(EitherError::B), + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match &mut *self { + EitherOutput::First(a) => TryStream::try_poll_next(Pin::new(a), cx) + .map(|v| v.map(|r| r.map_err(EitherError::A))), + EitherOutput::Second(b) => TryStream::try_poll_next(Pin::new(b), cx) + .map(|v| v.map(|r| r.map_err(EitherError::B))), } } } -impl Sink for EitherOutput +impl Sink for EitherOutput where - A: Sink, - B: Sink, + A: Sink + Unpin, + B: Sink + Unpin, { - type SinkItem = I; - type SinkError = EitherError; + type Error = EitherError; - fn start_send(&mut self, item: Self::SinkItem) -> StartSend { - match self { - EitherOutput::First(a) => a.start_send(item).map_err(EitherError::A), - EitherOutput::Second(b) => b.start_send(item).map_err(EitherError::B), + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match &mut *self { + EitherOutput::First(a) => Sink::poll_ready(Pin::new(a), cx).map_err(EitherError::A), + EitherOutput::Second(b) => Sink::poll_ready(Pin::new(b), cx).map_err(EitherError::B), } } - fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { - match self { - EitherOutput::First(a) => a.poll_complete().map_err(EitherError::A), - EitherOutput::Second(b) => b.poll_complete().map_err(EitherError::B), + fn start_send(mut self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { + match &mut *self { + EitherOutput::First(a) => Sink::start_send(Pin::new(a), item).map_err(EitherError::A), + EitherOutput::Second(b) => Sink::start_send(Pin::new(b), item).map_err(EitherError::B), } } - fn close(&mut self) -> Poll<(), Self::SinkError> { - match self { - EitherOutput::First(a) => a.close().map_err(EitherError::A), - EitherOutput::Second(b) => b.close().map_err(EitherError::B), + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match &mut *self { + EitherOutput::First(a) => Sink::poll_flush(Pin::new(a), cx).map_err(EitherError::A), + EitherOutput::Second(b) => Sink::poll_flush(Pin::new(b), cx).map_err(EitherError::B), + } + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match &mut *self { + EitherOutput::First(a) => Sink::poll_close(Pin::new(a), cx).map_err(EitherError::A), + EitherOutput::Second(b) => Sink::poll_close(Pin::new(b), cx).map_err(EitherError::B), } } } @@ -184,10 +206,10 @@ where type OutboundSubstream = EitherOutbound; type Error = IoError; - fn poll_inbound(&self) -> Poll { + fn poll_inbound(&self, cx: &mut Context) -> Poll> { match self { - EitherOutput::First(inner) => inner.poll_inbound().map(|p| p.map(EitherOutput::First)).map_err(|e| e.into()), - EitherOutput::Second(inner) => inner.poll_inbound().map(|p| p.map(EitherOutput::Second)).map_err(|e| e.into()), + EitherOutput::First(inner) => inner.poll_inbound(cx).map(|p| p.map(EitherOutput::First)).map_err(|e| e.into()), + EitherOutput::Second(inner) => inner.poll_inbound(cx).map(|p| p.map(EitherOutput::Second)).map_err(|e| e.into()), } } @@ -198,13 +220,13 @@ where } } - fn poll_outbound(&self, substream: &mut Self::OutboundSubstream) -> Poll { + fn poll_outbound(&self, cx: &mut Context, substream: &mut 
Self::OutboundSubstream) -> Poll> { match (self, substream) { (EitherOutput::First(ref inner), EitherOutbound::A(ref mut substream)) => { - inner.poll_outbound(substream).map(|p| p.map(EitherOutput::First)).map_err(|e| e.into()) + inner.poll_outbound(cx, substream).map(|p| p.map(EitherOutput::First)).map_err(|e| e.into()) }, (EitherOutput::Second(ref inner), EitherOutbound::B(ref mut substream)) => { - inner.poll_outbound(substream).map(|p| p.map(EitherOutput::Second)).map_err(|e| e.into()) + inner.poll_outbound(cx, substream).map(|p| p.map(EitherOutput::Second)).map_err(|e| e.into()) }, _ => panic!("Wrong API usage") } @@ -227,56 +249,56 @@ where } } - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + unsafe fn initializer(&self) -> Initializer { match self { - EitherOutput::First(ref inner) => inner.prepare_uninitialized_buffer(buf), - EitherOutput::Second(ref inner) => inner.prepare_uninitialized_buffer(buf), + EitherOutput::First(ref inner) => inner.initializer(), + EitherOutput::Second(ref inner) => inner.initializer(), } } - fn read_substream(&self, sub: &mut Self::Substream, buf: &mut [u8]) -> Poll { + fn read_substream(&self, cx: &mut Context, sub: &mut Self::Substream, buf: &mut [u8]) -> Poll> { match (self, sub) { (EitherOutput::First(ref inner), EitherOutput::First(ref mut sub)) => { - inner.read_substream(sub, buf).map_err(|e| e.into()) + inner.read_substream(cx, sub, buf).map_err(|e| e.into()) }, (EitherOutput::Second(ref inner), EitherOutput::Second(ref mut sub)) => { - inner.read_substream(sub, buf).map_err(|e| e.into()) + inner.read_substream(cx, sub, buf).map_err(|e| e.into()) }, _ => panic!("Wrong API usage") } } - fn write_substream(&self, sub: &mut Self::Substream, buf: &[u8]) -> Poll { + fn write_substream(&self, cx: &mut Context, sub: &mut Self::Substream, buf: &[u8]) -> Poll> { match (self, sub) { (EitherOutput::First(ref inner), EitherOutput::First(ref mut sub)) => { - inner.write_substream(sub, buf).map_err(|e| e.into()) + inner.write_substream(cx, sub, buf).map_err(|e| e.into()) }, (EitherOutput::Second(ref inner), EitherOutput::Second(ref mut sub)) => { - inner.write_substream(sub, buf).map_err(|e| e.into()) + inner.write_substream(cx, sub, buf).map_err(|e| e.into()) }, _ => panic!("Wrong API usage") } } - fn flush_substream(&self, sub: &mut Self::Substream) -> Poll<(), Self::Error> { + fn flush_substream(&self, cx: &mut Context, sub: &mut Self::Substream) -> Poll> { match (self, sub) { (EitherOutput::First(ref inner), EitherOutput::First(ref mut sub)) => { - inner.flush_substream(sub).map_err(|e| e.into()) + inner.flush_substream(cx, sub).map_err(|e| e.into()) }, (EitherOutput::Second(ref inner), EitherOutput::Second(ref mut sub)) => { - inner.flush_substream(sub).map_err(|e| e.into()) + inner.flush_substream(cx, sub).map_err(|e| e.into()) }, _ => panic!("Wrong API usage") } } - fn shutdown_substream(&self, sub: &mut Self::Substream) -> Poll<(), Self::Error> { + fn shutdown_substream(&self, cx: &mut Context, sub: &mut Self::Substream) -> Poll> { match (self, sub) { (EitherOutput::First(ref inner), EitherOutput::First(ref mut sub)) => { - inner.shutdown_substream(sub).map_err(|e| e.into()) + inner.shutdown_substream(cx, sub).map_err(|e| e.into()) }, (EitherOutput::Second(ref inner), EitherOutput::Second(ref mut sub)) => { - inner.shutdown_substream(sub).map_err(|e| e.into()) + inner.shutdown_substream(cx, sub).map_err(|e| e.into()) }, _ => panic!("Wrong API usage") } @@ -306,17 +328,17 @@ where } } - fn close(&self) -> Poll<(), 
Self::Error> { + fn close(&self, cx: &mut Context) -> Poll> { match self { - EitherOutput::First(inner) => inner.close().map_err(|e| e.into()), - EitherOutput::Second(inner) => inner.close().map_err(|e| e.into()), + EitherOutput::First(inner) => inner.close(cx).map_err(|e| e.into()), + EitherOutput::Second(inner) => inner.close(cx).map_err(|e| e.into()), } } - fn flush_all(&self) -> Poll<(), Self::Error> { + fn flush_all(&self, cx: &mut Context) -> Poll> { match self { - EitherOutput::First(inner) => inner.flush_all().map_err(|e| e.into()), - EitherOutput::Second(inner) => inner.flush_all().map_err(|e| e.into()), + EitherOutput::First(inner) => inner.flush_all(cx).map_err(|e| e.into()), + EitherOutput::Second(inner) => inner.flush_all(cx).map_err(|e| e.into()), } } } @@ -338,20 +360,25 @@ pub enum EitherListenStream { impl Stream for EitherListenStream where - AStream: Stream>, - BStream: Stream>, + AStream: TryStream> + Unpin, + BStream: TryStream> + Unpin, { - type Item = ListenerEvent>; - type Error = EitherError; + type Item = Result>, EitherError>; - fn poll(&mut self) -> Poll, Self::Error> { - match self { - EitherListenStream::First(a) => a.poll() - .map(|i| (i.map(|v| (v.map(|e| e.map(EitherFuture::First)))))) - .map_err(EitherError::A), - EitherListenStream::Second(a) => a.poll() - .map(|i| (i.map(|v| (v.map(|e| e.map(EitherFuture::Second)))))) - .map_err(EitherError::B), + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match &mut *self { + EitherListenStream::First(a) => match TryStream::try_poll_next(Pin::new(a), cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(None) => Poll::Ready(None), + Poll::Ready(Some(Ok(le))) => Poll::Ready(Some(Ok(le.map(EitherFuture::First)))), + Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(EitherError::A(err)))), + }, + EitherListenStream::Second(a) => match TryStream::try_poll_next(Pin::new(a), cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(None) => Poll::Ready(None), + Poll::Ready(Some(Ok(le))) => Poll::Ready(Some(Ok(le.map(EitherFuture::Second)))), + Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(EitherError::B(err)))), + }, } } } @@ -366,16 +393,17 @@ pub enum EitherFuture { impl Future for EitherFuture where - AFuture: Future, - BFuture: Future, + AFuture: TryFuture + Unpin, + BFuture: TryFuture + Unpin, { - type Item = EitherOutput; - type Error = EitherError; + type Output = Result, EitherError>; - fn poll(&mut self) -> Poll { - match self { - EitherFuture::First(a) => a.poll().map(|v| v.map(EitherOutput::First)).map_err(EitherError::A), - EitherFuture::Second(a) => a.poll().map(|v| v.map(EitherOutput::Second)).map_err(EitherError::B), + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + match &mut *self { + EitherFuture::First(a) => TryFuture::try_poll(Pin::new(a), cx) + .map_ok(EitherOutput::First).map_err(EitherError::A), + EitherFuture::Second(a) => TryFuture::try_poll(Pin::new(a), cx) + .map_ok(EitherOutput::Second).map_err(EitherError::B), } } } @@ -386,21 +414,17 @@ pub enum EitherFuture2 { A(A), B(B) } impl Future for EitherFuture2 where - AFut: Future, - BFut: Future + AFut: TryFuture + Unpin, + BFut: TryFuture + Unpin, { - type Item = EitherOutput; - type Error = EitherError; + type Output = Result, EitherError>; - fn poll(&mut self) -> Poll { - match self { - EitherFuture2::A(a) => a.poll() - .map(|v| v.map(EitherOutput::First)) - .map_err(EitherError::A), - - EitherFuture2::B(b) => b.poll() - .map(|v| v.map(EitherOutput::Second)) - .map_err(EitherError::B) + fn poll(mut 
self: Pin<&mut Self>, cx: &mut Context) -> Poll { + match &mut *self { + EitherFuture2::A(a) => TryFuture::try_poll(Pin::new(a), cx) + .map_ok(EitherOutput::First).map_err(EitherError::A), + EitherFuture2::B(a) => TryFuture::try_poll(Pin::new(a), cx) + .map_ok(EitherOutput::Second).map_err(EitherError::B), } } } diff --git a/core/src/lib.rs b/core/src/lib.rs index c3276415..471e928f 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -18,6 +18,8 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +#![cfg_attr(feature = "async-await", feature(async_await))] + //! Transports, upgrades, multiplexing and node handling of *libp2p*. //! //! The main concepts of libp2p-core are: @@ -37,15 +39,12 @@ /// Multi-address re-export. pub use multiaddr; -pub use multistream_select::Negotiated; +pub type Negotiated = futures::compat::Compat01As03>>; mod keys_proto; mod peer_id; mod translation; -#[cfg(test)] -mod tests; - pub mod either; pub mod identity; pub mod muxing; diff --git a/core/src/muxing.rs b/core/src/muxing.rs index 28245666..0ed2068a 100644 --- a/core/src/muxing.rs +++ b/core/src/muxing.rs @@ -52,13 +52,9 @@ //! implementation of `StreamMuxer` to control everything that happens on the wire. use fnv::FnvHashMap; -use futures::{future, prelude::*, try_ready}; +use futures::{future, prelude::*, io::Initializer, task::Context, task::Poll}; use parking_lot::Mutex; -use std::io::{self, Read, Write}; -use std::ops::Deref; -use std::fmt; -use std::sync::atomic::{AtomicUsize, Ordering}; -use tokio_io::{AsyncRead, AsyncWrite}; +use std::{io, ops::Deref, fmt, pin::Pin, sync::atomic::{AtomicUsize, Ordering}}; pub use self::singleton::SingletonMuxer; @@ -90,12 +86,12 @@ pub trait StreamMuxer { /// /// This function behaves the same as a `Stream`. /// - /// If `NotReady` is returned, then the current task will be notified once the muxer + /// If `Pending` is returned, then the current task will be notified once the muxer /// is ready to be polled, similar to the API of `Stream::poll()`. /// Only the latest task that was used to call this method may be notified. /// /// An error can be generated if the connection has been closed. - fn poll_inbound(&self) -> Poll; + fn poll_inbound(&self, cx: &mut Context) -> Poll>; /// Opens a new outgoing substream, and produces the equivalent to a future that will be /// resolved when it becomes available. @@ -106,22 +102,23 @@ pub trait StreamMuxer { /// Polls the outbound substream. /// - /// If `NotReady` is returned, then the current task will be notified once the substream + /// If `Pending` is returned, then the current task will be notified once the substream /// is ready to be polled, similar to the API of `Future::poll()`. /// However, for each individual outbound substream, only the latest task that was used to /// call this method may be notified. /// /// May panic or produce an undefined result if an earlier polling of the same substream /// returned `Ready` or `Err`. - fn poll_outbound(&self, s: &mut Self::OutboundSubstream) -> Poll; + fn poll_outbound(&self, cx: &mut Context, s: &mut Self::OutboundSubstream) + -> Poll>; /// Destroys an outbound substream future. Use this after the outbound substream has finished, /// or if you want to interrupt it. fn destroy_outbound(&self, s: Self::OutboundSubstream); - /// Reads data from a substream. The behaviour is the same as `tokio_io::AsyncRead::poll_read`. + /// Reads data from a substream. The behaviour is the same as `futures::AsyncRead::poll_read`. 
/// - /// If `NotReady` is returned, then the current task will be notified once the substream + /// If `Pending` is returned, then the current task will be notified once the substream /// is ready to be read. However, for each individual substream, only the latest task that /// was used to call this method may be notified. /// @@ -130,25 +127,17 @@ pub trait StreamMuxer { /// /// An error can be generated if the connection has been closed, or if a protocol misbehaviour /// happened. - fn read_substream(&self, s: &mut Self::Substream, buf: &mut [u8]) -> Poll; + fn read_substream(&self, cx: &mut Context, s: &mut Self::Substream, buf: &mut [u8]) + -> Poll>; - /// Mimics the `prepare_uninitialized_buffer` method of the `AsyncRead` trait. - /// - /// This function isn't actually unsafe to call but unsafe to implement. The implementer must - /// ensure that either the whole buf has been zeroed or that `read_substream` overwrites the - /// buffer without reading it and returns correct value. - /// - /// If this function returns true, then the memory has been zeroed out. This allows - /// implementations of `AsyncRead` which are composed of multiple subimplementations to - /// efficiently implement `prepare_uninitialized_buffer`. - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { - for b in buf.iter_mut() { *b = 0; } - true + /// Mimics the `initializer` method of the `AsyncRead` trait. + unsafe fn initializer(&self) -> Initializer { + Initializer::zeroing() } - /// Write data to a substream. The behaviour is the same as `tokio_io::AsyncWrite::poll_write`. + /// Write data to a substream. The behaviour is the same as `futures::AsyncWrite::poll_write`. /// - /// If `NotReady` is returned, then the current task will be notified once the substream + /// If `Pending` is returned, then the current task will be notified once the substream /// is ready to be read. For each individual substream, only the latest task that was used to /// call this method may be notified. /// @@ -157,24 +146,26 @@ pub trait StreamMuxer { /// /// It is incorrect to call this method on a substream if you called `shutdown_substream` on /// this substream earlier. - fn write_substream(&self, s: &mut Self::Substream, buf: &[u8]) -> Poll; + fn write_substream(&self, cx: &mut Context, s: &mut Self::Substream, buf: &[u8]) + -> Poll>; - /// Flushes a substream. The behaviour is the same as `tokio_io::AsyncWrite::poll_flush`. + /// Flushes a substream. The behaviour is the same as `futures::AsyncWrite::poll_flush`. /// /// After this method has been called, data written earlier on the substream is guaranteed to /// be received by the remote. /// - /// If `NotReady` is returned, then the current task will be notified once the substream + /// If `Pending` is returned, then the current task will be notified once the substream /// is ready to be read. For each individual substream, only the latest task that was used to /// call this method may be notified. /// /// > **Note**: This method may be implemented as a call to `flush_all`. - fn flush_substream(&self, s: &mut Self::Substream) -> Poll<(), Self::Error>; + fn flush_substream(&self, cx: &mut Context, s: &mut Self::Substream) + -> Poll>; /// Attempts to shut down the writing side of a substream. The behaviour is similar to - /// `tokio_io::AsyncWrite::shutdown`. + /// `AsyncWrite::poll_close`. 
/// - /// Contrary to `AsyncWrite::shutdown`, shutting down a substream does not imply + /// Contrary to `AsyncWrite::poll_close`, shutting down a substream does not imply /// `flush_substream`. If you want to make sure that the remote is immediately informed about /// the shutdown, use `flush_substream` or `flush_all`. /// @@ -182,7 +173,8 @@ pub trait StreamMuxer { /// /// An error can be generated if the connection has been closed, or if a protocol misbehaviour /// happened. - fn shutdown_substream(&self, s: &mut Self::Substream) -> Poll<(), Self::Error>; + fn shutdown_substream(&self, cx: &mut Context, s: &mut Self::Substream) + -> Poll>; /// Destroys a substream. fn destroy_substream(&self, s: Self::Substream); @@ -197,7 +189,7 @@ pub trait StreamMuxer { /// Closes this `StreamMuxer`. /// - /// After this has returned `Ok(Async::Ready(()))`, the muxer has become useless. All + /// After this has returned `Poll::Ready(Ok(()))`, the muxer has become useless. All /// subsequent reads must return either `EOF` or an error. All subsequent writes, shutdowns, /// or polls must generate an error or be ignored. /// @@ -207,14 +199,14 @@ pub trait StreamMuxer { /// > that the remote is properly informed of the shutdown. However, apart from /// > properly informing the remote, there is no difference between this and /// > immediately dropping the muxer. - fn close(&self) -> Poll<(), Self::Error>; + fn close(&self, cx: &mut Context) -> Poll>; /// Flush this `StreamMuxer`. /// /// This drains any write buffers of substreams and delivers any pending shutdown notifications /// due to `shutdown_substream` or `close`. One may thus shutdown groups of substreams /// followed by a final `flush_all` instead of having to do `flush_substream` for each. - fn flush_all(&self) -> Poll<(), Self::Error>; + fn flush_all(&self, cx: &mut Context) -> Poll>; } /// Polls for an inbound from the muxer but wraps the output in an object that @@ -222,14 +214,14 @@ pub trait StreamMuxer { #[inline] pub fn inbound_from_ref_and_wrap
<P>
( muxer: P, -) -> impl Future, Error = ::Error> +) -> impl Future, ::Error>> where P: Deref + Clone, P::Target: StreamMuxer, { let muxer2 = muxer.clone(); - future::poll_fn(move || muxer.poll_inbound()) - .map(|substream| substream_from_ref(muxer2, substream)) + future::poll_fn(move |cx| muxer.poll_inbound(cx)) + .map_ok(|substream| substream_from_ref(muxer2, substream)) } /// Same as `outbound_from_ref`, but wraps the output in an object that @@ -258,17 +250,16 @@ where P: Deref + Clone, P::Target: StreamMuxer, { - type Item = SubstreamRef
<P>
; - type Error = ::Error; + type Output = Result, ::Error>; - fn poll(&mut self) -> Poll { - match self.inner.poll() { - Ok(Async::Ready(substream)) => { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + match Future::poll(Pin::new(&mut self.inner), cx) { + Poll::Ready(Ok(substream)) => { let out = substream_from_ref(self.inner.muxer.clone(), substream); - Ok(Async::Ready(out)) + Poll::Ready(Ok(out)) } - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(err) => Err(err), + Poll::Pending => Poll::Pending, + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), } } } @@ -297,18 +288,26 @@ where outbound: Option<::OutboundSubstream>, } +impl
<P>
Unpin for OutboundSubstreamRefFuture
<P>
+where + P: Deref, + P::Target: StreamMuxer, +{ +} + impl
<P>
Future for OutboundSubstreamRefFuture
<P>
where P: Deref, P::Target: StreamMuxer, { - type Item = ::Substream; - type Error = ::Error; + type Output = Result<::Substream, ::Error>; #[inline] - fn poll(&mut self) -> Poll { - self.muxer - .poll_outbound(self.outbound.as_mut().expect("outbound was empty")) + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + // We use a `this` because the compiler isn't smart enough to allow mutably borrowing + // multiple different fields from the `Pin` at the same time. + let this = &mut *self; + this.muxer.poll_outbound(cx, this.outbound.as_mut().expect("outbound was empty")) } } @@ -370,20 +369,11 @@ where } } - -impl
<P>
Read for SubstreamRef
<P>
+impl
<P>
Unpin for SubstreamRef
<P>
where P: Deref, P::Target: StreamMuxer, { - #[inline] - fn read(&mut self, buf: &mut [u8]) -> Result { - let s = self.substream.as_mut().expect("substream was empty"); - match self.muxer.read_substream(s, buf).map_err(|e| e.into())? { - Async::Ready(n) => Ok(n), - Async::NotReady => Err(io::ErrorKind::WouldBlock.into()) - } - } } impl
<P>
AsyncRead for SubstreamRef
<P>
@@ -391,37 +381,17 @@ where P: Deref, P::Target: StreamMuxer, { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { - self.muxer.prepare_uninitialized_buffer(buf) + unsafe fn initializer(&self) -> Initializer { + self.muxer.initializer() } - fn poll_read(&mut self, buf: &mut [u8]) -> Poll { - let s = self.substream.as_mut().expect("substream was empty"); - self.muxer.read_substream(s, buf).map_err(|e| e.into()) - } -} + fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { + // We use a `this` because the compiler isn't smart enough to allow mutably borrowing + // multiple different fields from the `Pin` at the same time. + let this = &mut *self; -impl
<P>
Write for SubstreamRef
<P>
-where - P: Deref, - P::Target: StreamMuxer, -{ - #[inline] - fn write(&mut self, buf: &[u8]) -> Result { - let s = self.substream.as_mut().expect("substream was empty"); - match self.muxer.write_substream(s, buf).map_err(|e| e.into())? { - Async::Ready(n) => Ok(n), - Async::NotReady => Err(io::ErrorKind::WouldBlock.into()) - } - } - - #[inline] - fn flush(&mut self) -> Result<(), io::Error> { - let s = self.substream.as_mut().expect("substream was empty"); - match self.muxer.flush_substream(s).map_err(|e| e.into())? { - Async::Ready(()) => Ok(()), - Async::NotReady => Err(io::ErrorKind::WouldBlock.into()) - } + let s = this.substream.as_mut().expect("substream was empty"); + this.muxer.read_substream(cx, s, buf).map_err(|e| e.into()) } } @@ -430,36 +400,51 @@ where P: Deref, P::Target: StreamMuxer, { - #[inline] - fn poll_write(&mut self, buf: &[u8]) -> Poll { - let s = self.substream.as_mut().expect("substream was empty"); - self.muxer.write_substream(s, buf).map_err(|e| e.into()) + fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + // We use a `this` because the compiler isn't smart enough to allow mutably borrowing + // multiple different fields from the `Pin` at the same time. + let this = &mut *self; + + let s = this.substream.as_mut().expect("substream was empty"); + this.muxer.write_substream(cx, s, buf).map_err(|e| e.into()) } - #[inline] - fn shutdown(&mut self) -> Poll<(), io::Error> { - let s = self.substream.as_mut().expect("substream was empty"); + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + // We use a `this` because the compiler isn't smart enough to allow mutably borrowing + // multiple different fields from the `Pin` at the same time. + let this = &mut *self; + + let s = this.substream.as_mut().expect("substream was empty"); loop { - match self.shutdown_state { + match this.shutdown_state { ShutdownState::Shutdown => { - try_ready!(self.muxer.shutdown_substream(s).map_err(|e| e.into())); - self.shutdown_state = ShutdownState::Flush; + match this.muxer.shutdown_substream(cx, s) { + Poll::Ready(Ok(())) => this.shutdown_state = ShutdownState::Flush, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())), + Poll::Pending => return Poll::Pending, + } } ShutdownState::Flush => { - try_ready!(self.muxer.flush_substream(s).map_err(|e| e.into())); - self.shutdown_state = ShutdownState::Done; + match this.muxer.flush_substream(cx, s) { + Poll::Ready(Ok(())) => this.shutdown_state = ShutdownState::Done, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())), + Poll::Pending => return Poll::Pending, + } } ShutdownState::Done => { - return Ok(Async::Ready(())); + return Poll::Ready(Ok(())); } } } } - #[inline] - fn poll_flush(&mut self) -> Poll<(), io::Error> { - let s = self.substream.as_mut().expect("substream was empty"); - self.muxer.flush_substream(s).map_err(|e| e.into()) + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + // We use a `this` because the compiler isn't smart enough to allow mutably borrowing + // multiple different fields from the `Pin` at the same time. 
+ let this = &mut *self; + + let s = this.substream.as_mut().expect("substream was empty"); + this.muxer.flush_substream(cx, s).map_err(|e| e.into()) } } @@ -507,8 +492,8 @@ impl StreamMuxer for StreamMuxerBox { type Error = io::Error; #[inline] - fn poll_inbound(&self) -> Poll { - self.inner.poll_inbound() + fn poll_inbound(&self, cx: &mut Context) -> Poll> { + self.inner.poll_inbound(cx) } #[inline] @@ -517,8 +502,8 @@ impl StreamMuxer for StreamMuxerBox { } #[inline] - fn poll_outbound(&self, s: &mut Self::OutboundSubstream) -> Poll { - self.inner.poll_outbound(s) + fn poll_outbound(&self, cx: &mut Context, s: &mut Self::OutboundSubstream) -> Poll> { + self.inner.poll_outbound(cx, s) } #[inline] @@ -526,28 +511,28 @@ impl StreamMuxer for StreamMuxerBox { self.inner.destroy_outbound(substream) } - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { - self.inner.prepare_uninitialized_buffer(buf) + unsafe fn initializer(&self) -> Initializer { + self.inner.initializer() } #[inline] - fn read_substream(&self, s: &mut Self::Substream, buf: &mut [u8]) -> Poll { - self.inner.read_substream(s, buf) + fn read_substream(&self, cx: &mut Context, s: &mut Self::Substream, buf: &mut [u8]) -> Poll> { + self.inner.read_substream(cx, s, buf) } #[inline] - fn write_substream(&self, s: &mut Self::Substream, buf: &[u8]) -> Poll { - self.inner.write_substream(s, buf) + fn write_substream(&self, cx: &mut Context, s: &mut Self::Substream, buf: &[u8]) -> Poll> { + self.inner.write_substream(cx, s, buf) } #[inline] - fn flush_substream(&self, s: &mut Self::Substream) -> Poll<(), Self::Error> { - self.inner.flush_substream(s) + fn flush_substream(&self, cx: &mut Context, s: &mut Self::Substream) -> Poll> { + self.inner.flush_substream(cx, s) } #[inline] - fn shutdown_substream(&self, s: &mut Self::Substream) -> Poll<(), Self::Error> { - self.inner.shutdown_substream(s) + fn shutdown_substream(&self, cx: &mut Context, s: &mut Self::Substream) -> Poll> { + self.inner.shutdown_substream(cx, s) } #[inline] @@ -556,8 +541,8 @@ impl StreamMuxer for StreamMuxerBox { } #[inline] - fn close(&self) -> Poll<(), Self::Error> { - self.inner.close() + fn close(&self, cx: &mut Context) -> Poll> { + self.inner.close(cx) } #[inline] @@ -566,8 +551,8 @@ impl StreamMuxer for StreamMuxerBox { } #[inline] - fn flush_all(&self) -> Poll<(), Self::Error> { - self.inner.flush_all() + fn flush_all(&self, cx: &mut Context) -> Poll> { + self.inner.flush_all(cx) } } @@ -588,11 +573,16 @@ where type Error = io::Error; #[inline] - fn poll_inbound(&self) -> Poll { - let substream = try_ready!(self.inner.poll_inbound().map_err(|e| e.into())); + fn poll_inbound(&self, cx: &mut Context) -> Poll> { + let substream = match self.inner.poll_inbound(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(s)) => s, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())), + }; + let id = self.next_substream.fetch_add(1, Ordering::Relaxed); self.substreams.lock().insert(id, substream); - Ok(Async::Ready(id)) + Poll::Ready(Ok(id)) } #[inline] @@ -606,13 +596,18 @@ where #[inline] fn poll_outbound( &self, + cx: &mut Context, substream: &mut Self::OutboundSubstream, - ) -> Poll { + ) -> Poll> { let mut list = self.outbound.lock(); - let substream = try_ready!(self.inner.poll_outbound(list.get_mut(substream).unwrap()).map_err(|e| e.into())); + let substream = match self.inner.poll_outbound(cx, list.get_mut(substream).unwrap()) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(s)) => s, + Poll::Ready(Err(err)) => 
return Poll::Ready(Err(err.into())), + }; let id = self.next_substream.fetch_add(1, Ordering::Relaxed); self.substreams.lock().insert(id, substream); - Ok(Async::Ready(id)) + Poll::Ready(Ok(id)) } #[inline] @@ -621,32 +616,32 @@ where self.inner.destroy_outbound(list.remove(&substream).unwrap()) } - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { - self.inner.prepare_uninitialized_buffer(buf) + unsafe fn initializer(&self) -> Initializer { + self.inner.initializer() } #[inline] - fn read_substream(&self, s: &mut Self::Substream, buf: &mut [u8]) -> Poll { + fn read_substream(&self, cx: &mut Context, s: &mut Self::Substream, buf: &mut [u8]) -> Poll> { let mut list = self.substreams.lock(); - self.inner.read_substream(list.get_mut(s).unwrap(), buf).map_err(|e| e.into()) + self.inner.read_substream(cx, list.get_mut(s).unwrap(), buf).map_err(|e| e.into()) } #[inline] - fn write_substream(&self, s: &mut Self::Substream, buf: &[u8]) -> Poll { + fn write_substream(&self, cx: &mut Context, s: &mut Self::Substream, buf: &[u8]) -> Poll> { let mut list = self.substreams.lock(); - self.inner.write_substream(list.get_mut(s).unwrap(), buf).map_err(|e| e.into()) + self.inner.write_substream(cx, list.get_mut(s).unwrap(), buf).map_err(|e| e.into()) } #[inline] - fn flush_substream(&self, s: &mut Self::Substream) -> Poll<(), Self::Error> { + fn flush_substream(&self, cx: &mut Context, s: &mut Self::Substream) -> Poll> { let mut list = self.substreams.lock(); - self.inner.flush_substream(list.get_mut(s).unwrap()).map_err(|e| e.into()) + self.inner.flush_substream(cx, list.get_mut(s).unwrap()).map_err(|e| e.into()) } #[inline] - fn shutdown_substream(&self, s: &mut Self::Substream) -> Poll<(), Self::Error> { + fn shutdown_substream(&self, cx: &mut Context, s: &mut Self::Substream) -> Poll> { let mut list = self.substreams.lock(); - self.inner.shutdown_substream(list.get_mut(s).unwrap()).map_err(|e| e.into()) + self.inner.shutdown_substream(cx, list.get_mut(s).unwrap()).map_err(|e| e.into()) } #[inline] @@ -656,8 +651,8 @@ where } #[inline] - fn close(&self) -> Poll<(), Self::Error> { - self.inner.close().map_err(|e| e.into()) + fn close(&self, cx: &mut Context) -> Poll> { + self.inner.close(cx).map_err(|e| e.into()) } #[inline] @@ -666,7 +661,7 @@ where } #[inline] - fn flush_all(&self) -> Poll<(), Self::Error> { - self.inner.flush_all().map_err(|e| e.into()) + fn flush_all(&self, cx: &mut Context) -> Poll> { + self.inner.flush_all(cx).map_err(|e| e.into()) } } diff --git a/core/src/muxing/singleton.rs b/core/src/muxing/singleton.rs index 7bec14ed..f85e22fd 100644 --- a/core/src/muxing/singleton.rs +++ b/core/src/muxing/singleton.rs @@ -19,10 +19,9 @@ // DEALINGS IN THE SOFTWARE. use crate::{Endpoint, muxing::StreamMuxer}; -use futures::prelude::*; +use futures::{prelude::*, io::Initializer}; use parking_lot::Mutex; -use std::{io, sync::atomic::{AtomicBool, Ordering}}; -use tokio_io::{AsyncRead, AsyncWrite}; +use std::{io, pin::Pin, sync::atomic::{AtomicBool, Ordering}, task::Context, task::Poll}; /// Implementation of `StreamMuxer` that allows only one substream on top of a connection, /// yielding the connection itself. 
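Because the `StreamMuxer` methods now take an explicit `&mut Context` rather than relying on the implicit futures-0.1 task, callers turn them into futures with `future::poll_fn`, as `inbound_from_ref_and_wrap` does in `muxing.rs` above. A hypothetical standalone sketch of that usage (the helper name and the external crate path are illustrative, not from the patch):

```rust
use futures::{future, prelude::*};
use libp2p_core::muxing::StreamMuxer;

// Illustrative helper: a future resolving to the next inbound substream of `muxer`.
fn next_inbound<M: StreamMuxer>(muxer: &M)
    -> impl Future<Output = Result<M::Substream, M::Error>> + '_
{
    future::poll_fn(move |cx| muxer.poll_inbound(cx))
}
```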
@@ -62,22 +61,22 @@ pub struct OutboundSubstream {} impl StreamMuxer for SingletonMuxer where - TSocket: AsyncRead + AsyncWrite, + TSocket: AsyncRead + AsyncWrite + Unpin, { type Substream = Substream; type OutboundSubstream = OutboundSubstream; type Error = io::Error; - fn poll_inbound(&self) -> Poll { + fn poll_inbound(&self, _: &mut Context) -> Poll> { match self.endpoint { - Endpoint::Dialer => return Ok(Async::NotReady), + Endpoint::Dialer => return Poll::Pending, Endpoint::Listener => {} } if !self.substream_extracted.swap(true, Ordering::Relaxed) { - Ok(Async::Ready(Substream {})) + Poll::Ready(Ok(Substream {})) } else { - Ok(Async::NotReady) + Poll::Pending } } @@ -85,44 +84,44 @@ where OutboundSubstream {} } - fn poll_outbound(&self, _: &mut Self::OutboundSubstream) -> Poll { + fn poll_outbound(&self, _: &mut Context, _: &mut Self::OutboundSubstream) -> Poll> { match self.endpoint { - Endpoint::Listener => return Ok(Async::NotReady), + Endpoint::Listener => return Poll::Pending, Endpoint::Dialer => {} } if !self.substream_extracted.swap(true, Ordering::Relaxed) { - Ok(Async::Ready(Substream {})) + Poll::Ready(Ok(Substream {})) } else { - Ok(Async::NotReady) + Poll::Pending } } fn destroy_outbound(&self, _: Self::OutboundSubstream) { } - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { - self.inner.lock().prepare_uninitialized_buffer(buf) + unsafe fn initializer(&self) -> Initializer { + self.inner.lock().initializer() } - fn read_substream(&self, _: &mut Self::Substream, buf: &mut [u8]) -> Poll { - let res = self.inner.lock().poll_read(buf); - if let Ok(Async::Ready(_)) = res { + fn read_substream(&self, cx: &mut Context, _: &mut Self::Substream, buf: &mut [u8]) -> Poll> { + let res = AsyncRead::poll_read(Pin::new(&mut *self.inner.lock()), cx, buf); + if let Poll::Ready(Ok(_)) = res { self.remote_acknowledged.store(true, Ordering::Release); } res } - fn write_substream(&self, _: &mut Self::Substream, buf: &[u8]) -> Poll { - self.inner.lock().poll_write(buf) + fn write_substream(&self, cx: &mut Context, _: &mut Self::Substream, buf: &[u8]) -> Poll> { + AsyncWrite::poll_write(Pin::new(&mut *self.inner.lock()), cx, buf) } - fn flush_substream(&self, _: &mut Self::Substream) -> Poll<(), io::Error> { - self.inner.lock().poll_flush() + fn flush_substream(&self, cx: &mut Context, _: &mut Self::Substream) -> Poll> { + AsyncWrite::poll_flush(Pin::new(&mut *self.inner.lock()), cx) } - fn shutdown_substream(&self, _: &mut Self::Substream) -> Poll<(), io::Error> { - self.inner.lock().shutdown() + fn shutdown_substream(&self, cx: &mut Context, _: &mut Self::Substream) -> Poll> { + AsyncWrite::poll_close(Pin::new(&mut *self.inner.lock()), cx) } fn destroy_substream(&self, _: Self::Substream) { @@ -132,12 +131,12 @@ where self.remote_acknowledged.load(Ordering::Acquire) } - fn close(&self) -> Poll<(), io::Error> { + fn close(&self, cx: &mut Context) -> Poll> { // The `StreamMuxer` trait requires that `close()` implies `flush_all()`. 
- self.flush_all() + self.flush_all(cx) } - fn flush_all(&self) -> Poll<(), io::Error> { - self.inner.lock().poll_flush() + fn flush_all(&self, cx: &mut Context) -> Poll> { + AsyncWrite::poll_flush(Pin::new(&mut *self.inner.lock()), cx) } } diff --git a/core/src/nodes/collection.rs b/core/src/nodes/collection.rs index af8601d2..9e212810 100644 --- a/core/src/nodes/collection.rs +++ b/core/src/nodes/collection.rs @@ -29,11 +29,7 @@ use crate::{ }; use fnv::FnvHashMap; use futures::prelude::*; -use std::{error, fmt, hash::Hash, mem}; - -pub use crate::nodes::tasks::StartTakeOver; - -mod tests; +use std::{error, fmt, hash::Hash, mem, task::Context, task::Poll}; /// Implementation of `Stream` that handles a collection of nodes. pub struct CollectionStream { @@ -58,6 +54,9 @@ where } } +impl Unpin for + CollectionStream { } + /// State of a task. #[derive(Debug, Clone, PartialEq, Eq)] enum TaskState { @@ -323,7 +322,7 @@ where pub fn add_reach_attempt(&mut self, future: TFut, handler: THandler) -> ReachAttemptId where - TFut: Future + Send + 'static, + TFut: Future> + Unpin + Send + 'static, THandler: IntoNodeHandler + Send + 'static, THandler::Handler: NodeHandler, InEvent = TInEvent, OutEvent = TOutEvent, Error = THandlerErr> + Send + 'static, ::OutboundOpenInfo: Send + 'static, @@ -358,17 +357,19 @@ where } /// Sends an event to all nodes. - #[must_use] - pub fn start_broadcast(&mut self, event: &TInEvent) -> AsyncSink<()> + /// + /// Must be called only after a successful call to `poll_ready_broadcast`. + pub fn start_broadcast(&mut self, event: &TInEvent) where TInEvent: Clone { self.inner.start_broadcast(event) } + /// Wait until we have enough room in senders to broadcast an event. #[must_use] - pub fn complete_broadcast(&mut self) -> Async<()> { - self.inner.complete_broadcast() + pub fn poll_ready_broadcast(&mut self, cx: &mut Context) -> Poll<()> { + self.inner.poll_ready_broadcast(cx) } /// Adds an existing connection to a node to the collection. @@ -447,13 +448,13 @@ where /// > **Note**: we use a regular `poll` method instead of implementing `Stream` in order to /// > remove the `Err` variant, but also because we want the `CollectionStream` to stay /// > borrowed if necessary. 
- pub fn poll(&mut self) -> Async> + pub fn poll(&mut self, cx: &mut Context) -> Poll> where TConnInfo: Clone, // TODO: Clone shouldn't be necessary { - let item = match self.inner.poll() { - Async::Ready(item) => item, - Async::NotReady => return Async::NotReady, + let item = match self.inner.poll(cx) { + Poll::Ready(item) => item, + Poll::Pending => return Poll::Pending, }; match item { @@ -463,7 +464,7 @@ where match (user_data, result, handler) { (TaskState::Pending, tasks::Error::Reach(err), Some(handler)) => { - Async::Ready(CollectionEvent::ReachError { + Poll::Ready(CollectionEvent::ReachError { id: ReachAttemptId(id), error: err, handler, @@ -482,7 +483,7 @@ where debug_assert!(_handler.is_none()); let _node_task_id = self.nodes.remove(conn_info.peer_id()); debug_assert_eq!(_node_task_id, Some(id)); - Async::Ready(CollectionEvent::NodeClosed { + Poll::Ready(CollectionEvent::NodeClosed { conn_info, error: err, user_data, @@ -497,8 +498,8 @@ where tasks::Event::NodeReached { task, conn_info } => { let id = task.id(); drop(task); - Async::Ready(CollectionEvent::NodeReached(CollectionReachEvent { - parent: self, + Poll::Ready(CollectionEvent::NodeReached(CollectionReachEvent { + parent: &mut *self, id, conn_info: Some(conn_info), })) @@ -512,7 +513,7 @@ where self.tasks is switched to the Connected state; QED"), }; drop(task); - Async::Ready(CollectionEvent::NodeEvent { + Poll::Ready(CollectionEvent::NodeEvent { // TODO: normally we'd build a `PeerMut` manually here, but the borrow checker // doesn't like it peer: self.peer_mut(&conn_info.peer_id()) @@ -616,14 +617,15 @@ where } } - /// Sends an event to the given node. - pub fn start_send_event(&mut self, event: TInEvent) -> StartSend { + /// Begin sending an event to the given node. Must be called only after a successful call to + /// `poll_ready_event`. + pub fn start_send_event(&mut self, event: TInEvent) { self.inner.start_send_event(event) } - /// Complete sending an event message initiated by `start_send_event`. - pub fn complete_send_event(&mut self) -> Poll<(), ()> { - self.inner.complete_send_event() + /// Make sure we are ready to accept an event to be sent with `start_send_event`. + pub fn poll_ready_event(&mut self, cx: &mut Context) -> Poll<()> { + self.inner.poll_ready_event(cx) } /// Closes the connections to this node. Returns the user data. @@ -648,23 +650,13 @@ where /// The reach attempt will only be effectively cancelled once the peer (the object you're /// manipulating) has received some network activity. However no event will be ever be /// generated from this reach attempt, and this takes effect immediately. - #[must_use] - pub fn start_take_over(&mut self, id: InterruptedReachAttempt) - -> StartTakeOver<(), InterruptedReachAttempt> - { - match self.inner.start_take_over(id.inner) { - StartTakeOver::Ready(_state) => { - debug_assert!(if let TaskState::Pending = _state { true } else { false }); - StartTakeOver::Ready(()) - } - StartTakeOver::NotReady(inner) => - StartTakeOver::NotReady(InterruptedReachAttempt { inner }), - StartTakeOver::Gone => StartTakeOver::Gone - } + pub fn start_take_over(&mut self, id: InterruptedReachAttempt) { + self.inner.start_take_over(id.inner) } - /// Complete a take over initiated by `start_take_over`. - pub fn complete_take_over(&mut self) -> Poll<(), ()> { - self.inner.complete_take_over() + /// Make sure we are ready to taking over with `start_take_over`. 
+ #[must_use] + pub fn poll_ready_take_over(&mut self, cx: &mut Context) -> Poll<()> { + self.inner.poll_ready_take_over(cx) } } diff --git a/core/src/nodes/collection/tests.rs b/core/src/nodes/collection/tests.rs deleted file mode 100644 index 69f82c05..00000000 --- a/core/src/nodes/collection/tests.rs +++ /dev/null @@ -1,373 +0,0 @@ -// Copyright 2018 Parity Technologies (UK) Ltd. -// -// Permission is hereby granted, free of charge, to any person obtaining a -// copy of this software and associated documentation files (the "Software"), -// to deal in the Software without restriction, including without limitation -// the rights to use, copy, modify, merge, publish, distribute, sublicense, -// and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. - -#![cfg(test)] - -use super::*; -use assert_matches::assert_matches; -use futures::future; -use crate::tests::dummy_muxer::{DummyMuxer, DummyConnectionState}; -use crate::tests::dummy_handler::{Handler, InEvent, OutEvent, HandlerState}; -use tokio::runtime::current_thread::Runtime; -use tokio::runtime::Builder; -use crate::nodes::NodeHandlerEvent; -use std::{io, sync::Arc}; -use parking_lot::Mutex; - -type TestCollectionStream = CollectionStream; - -#[test] -fn has_connection_is_false_before_a_connection_has_been_made() { - let cs = TestCollectionStream::new(); - let peer_id = PeerId::random(); - assert!(!cs.has_connection(&peer_id)); -} - -#[test] -fn connections_is_empty_before_connecting() { - let cs = TestCollectionStream::new(); - assert!(cs.connections().next().is_none()); -} - -#[test] -fn retrieving_a_peer_is_none_if_peer_is_missing_or_not_connected() { - let mut cs = TestCollectionStream::new(); - let peer_id = PeerId::random(); - assert!(cs.peer_mut(&peer_id).is_none()); - - let handler = Handler::default(); - let fut = future::ok((peer_id.clone(), DummyMuxer::new())); - cs.add_reach_attempt(fut, handler); - assert!(cs.peer_mut(&peer_id).is_none()); // task is pending -} - -#[test] -fn collection_stream_reaches_the_nodes() { - let mut cs = TestCollectionStream::new(); - let peer_id = PeerId::random(); - - let mut muxer = DummyMuxer::new(); - muxer.set_inbound_connection_state(DummyConnectionState::Pending); - muxer.set_outbound_connection_state(DummyConnectionState::Opened); - - let fut = future::ok((peer_id, muxer)); - cs.add_reach_attempt(fut, Handler::default()); - let mut rt = Runtime::new().unwrap(); - let mut poll_count = 0; - let fut = future::poll_fn(move || -> Poll<(), ()> { - poll_count += 1; - let event = cs.poll(); - match poll_count { - 1 => assert_matches!(event, Async::NotReady), - 2 => { - assert_matches!(event, Async::Ready(CollectionEvent::NodeReached(_))); - return Ok(Async::Ready(())); // stop - } - _ => unreachable!() - } - Ok(Async::NotReady) - }); - rt.block_on(fut).unwrap(); -} - -#[test] -fn 
accepting_a_node_yields_new_entry() { - let mut cs = TestCollectionStream::new(); - let peer_id = PeerId::random(); - let fut = future::ok((peer_id.clone(), DummyMuxer::new())); - cs.add_reach_attempt(fut, Handler::default()); - - let mut rt = Runtime::new().unwrap(); - let mut poll_count = 0; - let fut = future::poll_fn(move || -> Poll<(), ()> { - poll_count += 1; - { - let event = cs.poll(); - match poll_count { - 1 => { - assert_matches!(event, Async::NotReady); - return Ok(Async::NotReady) - } - 2 => { - assert_matches!(event, Async::Ready(CollectionEvent::NodeReached(reach_ev)) => { - let (accept_ev, accepted_peer_id) = reach_ev.accept(()); - assert_eq!(accepted_peer_id, peer_id); - assert_matches!(accept_ev, CollectionNodeAccept::NewEntry); - }); - } - _ => unreachable!() - } - } - assert!(cs.peer_mut(&peer_id).is_some(), "peer is not in the list"); - assert!(cs.has_connection(&peer_id), "peer is not connected"); - assert_eq!(cs.connections().collect::>(), vec![&peer_id]); - Ok(Async::Ready(())) - }); - rt.block_on(fut).expect("running the future works"); -} - -#[test] -fn events_in_a_node_reaches_the_collection_stream() { - let cs = Arc::new(Mutex::new(TestCollectionStream::new())); - let task_peer_id = PeerId::random(); - - let mut handler = Handler::default(); - handler.state = Some(HandlerState::Ready(NodeHandlerEvent::Custom(OutEvent::Custom("init")))); - let handler_states = vec![ - HandlerState::Err, - HandlerState::Ready(NodeHandlerEvent::Custom(OutEvent::Custom("from handler 3") )), - HandlerState::Ready(NodeHandlerEvent::Custom(OutEvent::Custom("from handler 2") )), - HandlerState::Ready(NodeHandlerEvent::Custom(OutEvent::Custom("from handler 1") )), - ]; - handler.next_states = handler_states; - - let mut muxer = DummyMuxer::new(); - muxer.set_inbound_connection_state(DummyConnectionState::Pending); - muxer.set_outbound_connection_state(DummyConnectionState::Opened); - - let fut = future::ok((task_peer_id.clone(), muxer)); - cs.lock().add_reach_attempt(fut, handler); - - let mut rt = Builder::new().core_threads(1).build().unwrap(); - - let cs_fut = cs.clone(); - rt.block_on(future::poll_fn(move || -> Poll<_, ()> { - let mut cs = cs_fut.lock(); - assert_matches!(cs.poll(), Async::NotReady); - Ok(Async::Ready(())) - })).expect("tokio works"); - - let cs2 = cs.clone(); - rt.block_on(future::poll_fn(move || { - if cs2.lock().start_broadcast(&InEvent::NextState).is_not_ready() { - Ok::<_, ()>(Async::NotReady) - } else { - Ok(Async::Ready(())) - } - })).unwrap(); - let cs_fut = cs.clone(); - rt.block_on(future::poll_fn(move || -> Poll<_, ()> { - let mut cs = cs_fut.lock(); - if cs.complete_broadcast().is_not_ready() { - return Ok(Async::NotReady) - } - assert_matches!(cs.poll(), Async::Ready(CollectionEvent::NodeReached(reach_ev)) => { - reach_ev.accept(()); - }); - Ok(Async::Ready(())) - })).expect("tokio works"); - - let cs2 = cs.clone(); - rt.block_on(future::poll_fn(move || { - if cs2.lock().start_broadcast(&InEvent::NextState).is_not_ready() { - Ok::<_, ()>(Async::NotReady) - } else { - Ok(Async::Ready(())) - } - })).unwrap(); - let cs_fut = cs.clone(); - rt.block_on(future::poll_fn(move || -> Poll<_, ()> { - let mut cs = cs_fut.lock(); - if cs.complete_broadcast().is_not_ready() { - return Ok(Async::NotReady) - } - assert_matches!(cs.poll(), Async::Ready(CollectionEvent::NodeEvent{peer: _, event}) => { - assert_matches!(event, OutEvent::Custom("init")); - }); - Ok(Async::Ready(())) - })).expect("tokio works"); - - - let cs2 = cs.clone(); - rt.block_on(future::poll_fn(move 
|| { - if cs2.lock().start_broadcast(&InEvent::NextState).is_not_ready() { - Ok::<_, ()>(Async::NotReady) - } else { - Ok(Async::Ready(())) - } - })).unwrap(); - let cs_fut = cs.clone(); - rt.block_on(future::poll_fn(move || -> Poll<_, ()> { - let mut cs = cs_fut.lock(); - if cs.complete_broadcast().is_not_ready() { - return Ok(Async::NotReady) - } - assert_matches!(cs.poll(), Async::Ready(CollectionEvent::NodeEvent{peer: _, event}) => { - assert_matches!(event, OutEvent::Custom("from handler 1")); - }); - Ok(Async::Ready(())) - })).expect("tokio works"); - - let cs2 = cs.clone(); - rt.block_on(future::poll_fn(move || { - if cs2.lock().start_broadcast(&InEvent::NextState).is_not_ready() { - Ok::<_, ()>(Async::NotReady) - } else { - Ok(Async::Ready(())) - } - })).unwrap(); - let cs_fut = cs.clone(); - rt.block_on(future::poll_fn(move || -> Poll<_, ()> { - let mut cs = cs_fut.lock(); - if cs.complete_broadcast().is_not_ready() { - return Ok(Async::NotReady) - } - assert_matches!(cs.poll(), Async::Ready(CollectionEvent::NodeEvent{peer: _, event}) => { - assert_matches!(event, OutEvent::Custom("from handler 2")); - }); - Ok(Async::Ready(())) - })).expect("tokio works"); -} - -#[test] -fn task_closed_with_error_while_task_is_pending_yields_reach_error() { - let cs = Arc::new(Mutex::new(TestCollectionStream::new())); - let task_inner_fut = future::err(std::io::Error::new(std::io::ErrorKind::Other, "inner fut error")); - let reach_attempt_id = cs.lock().add_reach_attempt(task_inner_fut, Handler::default()); - - let mut rt = Builder::new().core_threads(1).build().unwrap(); - let cs_fut = cs.clone(); - rt.block_on(future::poll_fn(move || -> Poll<_, ()> { - let mut cs = cs_fut.lock(); - assert_matches!(cs.poll(), Async::NotReady); - Ok(Async::Ready(())) - })).expect("tokio works"); - - let cs_fut = cs.clone(); - rt.block_on(future::poll_fn(move || -> Poll<_, ()> { - let mut cs = cs_fut.lock(); - assert_matches!(cs.poll(), Async::Ready(collection_ev) => { - assert_matches!(collection_ev, CollectionEvent::ReachError {id, error, ..} => { - assert_eq!(id, reach_attempt_id); - assert_eq!(error.to_string(), "inner fut error"); - }); - - }); - Ok(Async::Ready(())) - })).expect("tokio works"); - -} - -#[test] -fn task_closed_with_error_when_task_is_connected_yields_node_error() { - let cs = Arc::new(Mutex::new(TestCollectionStream::new())); - let peer_id = PeerId::random(); - let muxer = DummyMuxer::new(); - let task_inner_fut = future::ok((peer_id.clone(), muxer)); - let mut handler = Handler::default(); - handler.next_states = vec![HandlerState::Err]; // triggered when sending a NextState event - - cs.lock().add_reach_attempt(task_inner_fut, handler); - let mut rt = Builder::new().core_threads(1).build().unwrap(); - - // Kick it off - let cs2 = cs.clone(); - rt.block_on(future::poll_fn(move || { - if cs2.lock().start_broadcast(&InEvent::NextState).is_not_ready() { - Ok::<_, ()>(Async::NotReady) - } else { - Ok(Async::Ready(())) - } - })).unwrap(); - let cs_fut = cs.clone(); - rt.block_on(future::poll_fn(move || -> Poll<_, ()> { - let mut cs = cs_fut.lock(); - assert_matches!(cs.poll(), Async::NotReady); - // send an event so the Handler errors in two polls - Ok(cs.complete_broadcast()) - })).expect("tokio works"); - - // Accept the new node - let cs_fut = cs.clone(); - rt.block_on(future::poll_fn(move || -> Poll<_, ()> { - let mut cs = cs_fut.lock(); - // NodeReached, accept the connection so the task transitions from Pending to Connected - assert_matches!(cs.poll(), 
Async::Ready(CollectionEvent::NodeReached(reach_ev)) => { - reach_ev.accept(()); - }); - Ok(Async::Ready(())) - })).expect("tokio works"); - - assert!(cs.lock().has_connection(&peer_id)); - - // Assert the node errored - let cs_fut = cs.clone(); - rt.block_on(future::poll_fn(move || -> Poll<_, ()> { - let mut cs = cs_fut.lock(); - assert_matches!(cs.poll(), Async::Ready(collection_ev) => { - assert_matches!(collection_ev, CollectionEvent::NodeClosed{..}); - }); - Ok(Async::Ready(())) - })).expect("tokio works"); -} - -#[test] -fn interrupting_a_pending_connection_attempt_is_ok() { - let mut cs = TestCollectionStream::new(); - let fut = future::empty(); - let reach_id = cs.add_reach_attempt(fut, Handler::default()); - let interrupt = cs.interrupt(reach_id); - assert!(interrupt.is_ok()); -} - -#[test] -fn interrupting_a_connection_attempt_twice_is_err() { - let mut cs = TestCollectionStream::new(); - let fut = future::empty(); - let reach_id = cs.add_reach_attempt(fut, Handler::default()); - assert!(cs.interrupt(reach_id).is_ok()); - assert_matches!(cs.interrupt(reach_id), Err(InterruptError::ReachAttemptNotFound)) -} - -#[test] -fn interrupting_an_established_connection_is_err() { - let cs = Arc::new(Mutex::new(TestCollectionStream::new())); - let peer_id = PeerId::random(); - let muxer = DummyMuxer::new(); - let task_inner_fut = future::ok((peer_id.clone(), muxer)); - let handler = Handler::default(); - - let reach_id = cs.lock().add_reach_attempt(task_inner_fut, handler); - let mut rt = Builder::new().core_threads(1).build().unwrap(); - - // Kick it off - let cs_fut = cs.clone(); - rt.block_on(future::poll_fn(move || -> Poll<_, ()> { - let mut cs = cs_fut.lock(); - assert_matches!(cs.poll(), Async::NotReady); - // send an event so the Handler errors in two polls - Ok(Async::Ready(())) - })).expect("tokio works"); - - // Accept the new node - let cs_fut = cs.clone(); - rt.block_on(future::poll_fn(move || -> Poll<_, ()> { - let mut cs = cs_fut.lock(); - // NodeReached, accept the connection so the task transitions from Pending to Connected - assert_matches!(cs.poll(), Async::Ready(CollectionEvent::NodeReached(reach_ev)) => { - reach_ev.accept(()); - }); - Ok(Async::Ready(())) - })).expect("tokio works"); - - assert!(cs.lock().has_connection(&peer_id), "Connection was not established"); - - assert_matches!(cs.lock().interrupt(reach_id), Err(InterruptError::AlreadyReached)); -} diff --git a/core/src/nodes/handled_node.rs b/core/src/nodes/handled_node.rs index 150b5e45..f8b08d11 100644 --- a/core/src/nodes/handled_node.rs +++ b/core/src/nodes/handled_node.rs @@ -20,10 +20,7 @@ use crate::{PeerId, muxing::StreamMuxer}; use crate::nodes::node::{NodeEvent, NodeStream, Substream, Close}; -use futures::prelude::*; -use std::{error, fmt, io}; - -mod tests; +use std::{error, fmt, io, pin::Pin, task::Context, task::Poll}; /// Handler for the substreams of a node. // TODO: right now it is possible for a node handler to be built, then shut down right after if we @@ -59,7 +56,8 @@ pub trait NodeHandler { /// Should behave like `Stream::poll()`. /// /// Returning an error will close the connection to the remote. - fn poll(&mut self) -> Poll, Self::Error>; + fn poll(&mut self, cx: &mut Context) + -> Poll, Self::Error>>; } /// Prototype for a `NodeHandler`. 
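The signature change just above is the heart of the migration: futures 0.1's `Poll<T, E>` (an alias for `Result<Async<T>, E>`) becomes the standard library's `Poll<Result<T, E>>`, and the task context is passed in explicitly instead of being picked up from thread-local state. Below is a minimal sketch of what this means for calling code, using made-up `Node`, `NodeEvent` and `NodeError` placeholder types rather than the real libp2p ones:

```rust
use std::task::{Context, Poll};

// Placeholder types standing in for the real node, its events and its errors.
enum NodeEvent { Inbound, Outbound }
struct NodeError;

struct Node;

impl Node {
    // New-style poll: readiness and errors both travel inside `Poll`.
    fn poll(&mut self, _cx: &mut Context<'_>) -> Poll<Result<NodeEvent, NodeError>> {
        Poll::Pending
    }
}

// With futures 0.1 the call site looked roughly like:
//
//     match self.node.poll().map_err(HandledNodeError::Node)? {
//         Async::NotReady => { /* nothing to do */ }
//         Async::Ready(event) => { /* handle event */ }
//     }
//
// With the std types, the same logic is written without `?` on the outer poll:
fn drive(node: &mut Node, cx: &mut Context<'_>) -> Poll<Result<NodeEvent, NodeError>> {
    match node.poll(cx) {
        Poll::Pending => Poll::Pending,
        Poll::Ready(Ok(event)) => Poll::Ready(Ok(event)),
        Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
    }
}
```

Errors no longer short-circuit through `?` on the outer poll; they are matched as an ordinary `Poll::Ready(Err(_))` variant, which is exactly the shape the `HandledNode::poll` rewrite in the next hunks takes.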
@@ -172,6 +170,13 @@ where } } +impl Unpin for HandledNode +where + TMuxer: StreamMuxer, + THandler: NodeHandler>, +{ +} + impl HandledNode where TMuxer: StreamMuxer, @@ -214,37 +219,41 @@ where } /// API similar to `Future::poll` that polls the node for events. - pub fn poll(&mut self) -> Poll> { + pub fn poll(mut self: Pin<&mut Self>, cx: &mut Context) + -> Poll>> + { loop { let mut node_not_ready = false; - match self.node.poll().map_err(HandledNodeError::Node)? { - Async::NotReady => node_not_ready = true, - Async::Ready(NodeEvent::InboundSubstream { substream }) => { + match self.node.poll(cx) { + Poll::Pending => node_not_ready = true, + Poll::Ready(Ok(NodeEvent::InboundSubstream { substream })) => { self.handler.inject_substream(substream, NodeHandlerEndpoint::Listener) } - Async::Ready(NodeEvent::OutboundSubstream { user_data, substream }) => { + Poll::Ready(Ok(NodeEvent::OutboundSubstream { user_data, substream })) => { let endpoint = NodeHandlerEndpoint::Dialer(user_data); self.handler.inject_substream(substream, endpoint) } + Poll::Ready(Err(err)) => return Poll::Ready(Err(HandledNodeError::Node(err))), } - match self.handler.poll().map_err(HandledNodeError::Handler)? { - Async::NotReady => { + match self.handler.poll(cx) { + Poll::Pending => { if node_not_ready { break } } - Async::Ready(NodeHandlerEvent::OutboundSubstreamRequest(user_data)) => { + Poll::Ready(Ok(NodeHandlerEvent::OutboundSubstreamRequest(user_data))) => { self.node.open_substream(user_data); } - Async::Ready(NodeHandlerEvent::Custom(event)) => { - return Ok(Async::Ready(event)); + Poll::Ready(Ok(NodeHandlerEvent::Custom(event))) => { + return Poll::Ready(Ok(event)); } + Poll::Ready(Err(err)) => return Poll::Ready(Err(HandledNodeError::Handler(err))), } } - Ok(Async::NotReady) + Poll::Pending } } diff --git a/core/src/nodes/handled_node/tests.rs b/core/src/nodes/handled_node/tests.rs deleted file mode 100644 index ee138c2e..00000000 --- a/core/src/nodes/handled_node/tests.rs +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright 2018 Parity Technologies (UK) Ltd. -// -// Permission is hereby granted, free of charge, to any person obtaining a -// copy of this software and associated documentation files (the "Software"), -// to deal in the Software without restriction, including without limitation -// the rights to use, copy, modify, merge, publish, distribute, sublicense, -// and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. 
- -#![cfg(test)] - -use super::*; -use assert_matches::assert_matches; -use crate::tests::dummy_muxer::{DummyMuxer, DummyConnectionState}; -use crate::tests::dummy_handler::{Handler, HandlerState, InEvent, OutEvent, TestHandledNode}; - -struct TestBuilder { - muxer: DummyMuxer, - handler: Handler, - want_open_substream: bool, - substream_user_data: usize, -} - -impl TestBuilder { - fn new() -> Self { - TestBuilder { - muxer: DummyMuxer::new(), - handler: Handler::default(), - want_open_substream: false, - substream_user_data: 0, - } - } - - fn with_muxer_inbound_state(&mut self, state: DummyConnectionState) -> &mut Self { - self.muxer.set_inbound_connection_state(state); - self - } - - fn with_muxer_outbound_state(&mut self, state: DummyConnectionState) -> &mut Self { - self.muxer.set_outbound_connection_state(state); - self - } - - fn with_handler_state(&mut self, state: HandlerState) -> &mut Self { - self.handler.state = Some(state); - self - } - - fn with_open_substream(&mut self, user_data: usize) -> &mut Self { - self.want_open_substream = true; - self.substream_user_data = user_data; - self - } - - fn handled_node(&mut self) -> TestHandledNode { - let mut h = HandledNode::new(self.muxer.clone(), self.handler.clone()); - if self.want_open_substream { - h.node.open_substream(self.substream_user_data); - } - h - } -} - -// Set the state of the `Handler` after `inject_outbound_closed` is called -fn set_next_handler_outbound_state( handled_node: &mut TestHandledNode, next_state: HandlerState) { - handled_node.handler.next_outbound_state = Some(next_state); -} - -#[test] -fn can_inject_event() { - let mut handled = TestBuilder::new() - .handled_node(); - - let event = InEvent::Custom("banana"); - handled.inject_event(event.clone()); - assert_eq!(handled.handler().events, vec![event]); -} - -#[test] -fn poll_with_unready_node_stream_and_handler_emits_custom_event() { - let expected_event = NodeHandlerEvent::Custom(OutEvent::Custom("pineapple")); - let mut handled = TestBuilder::new() - // make NodeStream return NotReady - .with_muxer_inbound_state(DummyConnectionState::Pending) - // make Handler return return Ready(Some(…)) - .with_handler_state(HandlerState::Ready(expected_event)) - .handled_node(); - - assert_matches!(handled.poll(), Ok(Async::Ready(event)) => { - assert_matches!(event, OutEvent::Custom("pineapple")) - }); -} - -#[test] -fn handler_emits_outbound_closed_when_opening_new_substream_on_closed_node() { - let open_event = NodeHandlerEvent::OutboundSubstreamRequest(456); - let mut handled = TestBuilder::new() - .with_muxer_inbound_state(DummyConnectionState::Pending) - .with_muxer_outbound_state(DummyConnectionState::Pending) - .with_handler_state(HandlerState::Ready(open_event)) - .handled_node(); - - set_next_handler_outbound_state( - &mut handled, - HandlerState::Ready(NodeHandlerEvent::Custom(OutEvent::Custom("pear"))) - ); - handled.poll().expect("poll works"); -} - -#[test] -fn poll_yields_inbound_closed_event() { - let mut h = TestBuilder::new() - .with_muxer_inbound_state(DummyConnectionState::Pending) - .with_handler_state(HandlerState::Err) // stop the loop - .handled_node(); - - assert_eq!(h.handler().events, vec![]); - let _ = h.poll(); -} - -#[test] -fn poll_yields_outbound_closed_event() { - let mut h = TestBuilder::new() - .with_muxer_inbound_state(DummyConnectionState::Pending) - .with_open_substream(32) - .with_muxer_outbound_state(DummyConnectionState::Pending) - .with_handler_state(HandlerState::Err) // stop the loop - .handled_node(); - - 
assert_eq!(h.handler().events, vec![]); - let _ = h.poll(); -} - -#[test] -fn poll_yields_outbound_substream() { - let mut h = TestBuilder::new() - .with_muxer_inbound_state(DummyConnectionState::Pending) - .with_muxer_outbound_state(DummyConnectionState::Opened) - .with_open_substream(1) - .with_handler_state(HandlerState::Err) // stop the loop - .handled_node(); - - assert_eq!(h.handler().events, vec![]); - let _ = h.poll(); - assert_eq!(h.handler().events, vec![InEvent::Substream(Some(1))]); -} - -#[test] -fn poll_yields_inbound_substream() { - let mut h = TestBuilder::new() - .with_muxer_inbound_state(DummyConnectionState::Opened) - .with_muxer_outbound_state(DummyConnectionState::Pending) - .with_handler_state(HandlerState::Err) // stop the loop - .handled_node(); - - assert_eq!(h.handler().events, vec![]); - let _ = h.poll(); - assert_eq!(h.handler().events, vec![InEvent::Substream(None)]); -} diff --git a/core/src/nodes/listeners.rs b/core/src/nodes/listeners.rs index effcea65..b9c8ebbf 100644 --- a/core/src/nodes/listeners.rs +++ b/core/src/nodes/listeners.rs @@ -21,11 +21,10 @@ //! Manage listening on multiple multiaddresses at once. use crate::{Multiaddr, Transport, transport::{TransportError, ListenerEvent}}; -use futures::prelude::*; +use futures::{prelude::*, task::Context, task::Poll}; use log::debug; use smallvec::SmallVec; -use std::{collections::VecDeque, fmt}; -use void::Void; +use std::{collections::VecDeque, fmt, pin::Pin}; /// Implementation of `futures::Stream` that allows listening on multiaddresses. /// @@ -158,7 +157,7 @@ where /// The ID of the listener that errored. listener_id: ListenerId, /// The error value. - error: ::Error + error: ::Error } } @@ -222,28 +221,31 @@ where self.listeners.iter().flat_map(|l| l.addresses.iter()) } - /// Provides an API similar to `Stream`, except that it cannot error. - pub fn poll(&mut self) -> Async> { + /// Provides an API similar to `Stream`, except that it cannot end. + pub fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> + where + TTrans::Listener: Unpin, + { // We remove each element from `listeners` one by one and add them back. 
let mut remaining = self.listeners.len(); while let Some(mut listener) = self.listeners.pop_back() { - match listener.listener.poll() { - Ok(Async::NotReady) => { + match TryStream::try_poll_next(Pin::new(&mut listener.listener), cx) { + Poll::Pending => { self.listeners.push_front(listener); remaining -= 1; if remaining == 0 { break } } - Ok(Async::Ready(Some(ListenerEvent::Upgrade { upgrade, local_addr, remote_addr }))) => { + Poll::Ready(Some(Ok(ListenerEvent::Upgrade { upgrade, local_addr, remote_addr }))) => { let id = listener.id; self.listeners.push_front(listener); - return Async::Ready(ListenersEvent::Incoming { + return Poll::Ready(ListenersEvent::Incoming { listener_id: id, upgrade, local_addr, send_back_addr: remote_addr }) } - Ok(Async::Ready(Some(ListenerEvent::NewAddress(a)))) => { + Poll::Ready(Some(Ok(ListenerEvent::NewAddress(a)))) => { if listener.addresses.contains(&a) { debug!("Transport has reported address {} multiple times", a) } @@ -252,28 +254,28 @@ where } let id = listener.id; self.listeners.push_front(listener); - return Async::Ready(ListenersEvent::NewAddress { + return Poll::Ready(ListenersEvent::NewAddress { listener_id: id, listen_addr: a }) } - Ok(Async::Ready(Some(ListenerEvent::AddressExpired(a)))) => { + Poll::Ready(Some(Ok(ListenerEvent::AddressExpired(a)))) => { listener.addresses.retain(|x| x != &a); let id = listener.id; self.listeners.push_front(listener); - return Async::Ready(ListenersEvent::AddressExpired { + return Poll::Ready(ListenersEvent::AddressExpired { listener_id: id, listen_addr: a }) } - Ok(Async::Ready(None)) => { - return Async::Ready(ListenersEvent::Closed { + Poll::Ready(None) => { + return Poll::Ready(ListenersEvent::Closed { listener_id: listener.id, listener: listener.listener }) } - Err(err) => { - return Async::Ready(ListenersEvent::Error { + Poll::Ready(Some(Err(err))) => { + return Poll::Ready(ListenersEvent::Error { listener_id: listener.id, error: err }) @@ -282,22 +284,28 @@ where } // We register the current task to be woken up if a new listener is added. - Async::NotReady + Poll::Pending } } impl Stream for ListenersStream where TTrans: Transport, + TTrans::Listener: Unpin, { type Item = ListenersEvent; - type Error = Void; // TODO: use ! 
once stable - fn poll(&mut self) -> Poll, Self::Error> { - Ok(self.poll().map(Option::Some)) + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + ListenersStream::poll(self, cx).map(Option::Some) } } +impl Unpin for ListenersStream +where + TTrans: Transport, +{ +} + impl fmt::Debug for ListenersStream where TTrans: Transport + fmt::Debug, @@ -313,7 +321,7 @@ where impl fmt::Debug for ListenersEvent where TTrans: Transport, - ::Error: fmt::Debug, + ::Error: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { match self { @@ -353,215 +361,37 @@ mod tests { use tokio::runtime::current_thread::Runtime; use std::{io, iter::FromIterator}; use futures::{future::{self}, stream}; - use crate::tests::dummy_transport::{DummyTransport, ListenerState}; - use crate::tests::dummy_muxer::DummyMuxer; use crate::PeerId; - fn set_listener_state(ls: &mut ListenersStream, idx: usize, state: ListenerState) { - ls.listeners[idx].listener = match state { - ListenerState::Error => - Box::new(stream::poll_fn(|| Err(io::Error::new(io::ErrorKind::Other, "oh noes")))), - ListenerState::Ok(state) => match state { - Async::NotReady => Box::new(stream::poll_fn(|| Ok(Async::NotReady))), - Async::Ready(Some(event)) => Box::new(stream::poll_fn(move || { - Ok(Async::Ready(Some(event.clone().map(future::ok)))) - })), - Async::Ready(None) => Box::new(stream::empty()) - } - ListenerState::Events(events) => - Box::new(stream::iter_ok(events.into_iter().map(|e| e.map(future::ok)))) - }; - } - #[test] fn incoming_event() { - let mem_transport = transport::MemoryTransport::default(); + futures::executor::block_on(async move { + let mem_transport = transport::MemoryTransport::default(); - let mut listeners = ListenersStream::new(mem_transport); - listeners.listen_on("/memory/0".parse().unwrap()).unwrap(); + let mut listeners = ListenersStream::new(mem_transport); + listeners.listen_on("/memory/0".parse().unwrap()).unwrap(); - let address = { - let event = listeners.by_ref().wait().next().expect("some event").expect("no error"); - if let ListenersEvent::NewAddress { listen_addr, .. } = event { - listen_addr - } else { - panic!("Was expecting the listen address to be reported") - } - }; - - let dial = mem_transport.dial(address.clone()).unwrap(); - - let future = listeners - .into_future() - .map_err(|(err, _)| err) - .and_then(|(event, _)| { - match event { - Some(ListenersEvent::Incoming { local_addr, upgrade, send_back_addr, .. }) => { - assert_eq!(local_addr, address); - assert_eq!(send_back_addr, address); - upgrade.map(|_| ()).map_err(|_| panic!()) - }, - _ => panic!() + let address = { + let event = listeners.next().await.unwrap(); + if let ListenersEvent::NewAddress { listen_addr, .. 
} = event { + listen_addr + } else { + panic!("Was expecting the listen address to be reported") } - }) - .select(dial.map(|_| ()).map_err(|_| panic!())) - .map_err(|(err, _)| err); + }; - let mut runtime = Runtime::new().unwrap(); - runtime.block_on(future).unwrap(); - } + let address2 = address.clone(); + async_std::task::spawn(async move { + mem_transport.dial(address2).unwrap().await.unwrap(); + }); - #[test] - fn listener_stream_returns_transport() { - let t = DummyTransport::new(); - let t_clone = t.clone(); - let ls = ListenersStream::new(t); - assert_eq!(ls.transport(), &t_clone); - } - - #[test] - fn listener_stream_can_iterate_over_listeners() { - let mut t = DummyTransport::new(); - let addr1 = tcp4([127, 0, 0, 1], 1234); - let addr2 = tcp4([127, 0, 0, 1], 4321); - - t.set_initial_listener_state(ListenerState::Events(vec![ - ListenerEvent::NewAddress(addr1.clone()), - ListenerEvent::NewAddress(addr2.clone()) - ])); - - let mut ls = ListenersStream::new(t); - ls.listen_on(tcp4([0, 0, 0, 0], 0)).expect("listen_on"); - - assert_matches!(ls.by_ref().wait().next(), Some(Ok(ListenersEvent::NewAddress { listen_addr, .. })) => { - assert_eq!(addr1, listen_addr) - }); - assert_matches!(ls.by_ref().wait().next(), Some(Ok(ListenersEvent::NewAddress { listen_addr, .. })) => { - assert_eq!(addr2, listen_addr) - }) - } - - #[test] - fn listener_stream_poll_without_listeners_is_not_ready() { - let t = DummyTransport::new(); - let mut ls = ListenersStream::new(t); - assert_matches!(ls.poll(), Async::NotReady); - } - - #[test] - fn listener_stream_poll_with_listeners_that_arent_ready_is_not_ready() { - let t = DummyTransport::new(); - let addr = tcp4([127, 0, 0, 1], 1234); - let mut ls = ListenersStream::new(t); - ls.listen_on(addr).expect("listen_on failed"); - set_listener_state(&mut ls, 0, ListenerState::Ok(Async::NotReady)); - assert_matches!(ls.poll(), Async::NotReady); - assert_eq!(ls.listeners.len(), 1); // listener is still there - } - - #[test] - fn listener_stream_poll_with_ready_listeners_is_ready() { - let mut t = DummyTransport::new(); - let peer_id = PeerId::random(); - let muxer = DummyMuxer::new(); - let expected_output = (peer_id.clone(), muxer.clone()); - - t.set_initial_listener_state(ListenerState::Events(vec![ - ListenerEvent::NewAddress(tcp4([127, 0, 0, 1], 9090)), - ListenerEvent::Upgrade { - upgrade: (peer_id.clone(), muxer.clone()), - local_addr: tcp4([127, 0, 0, 1], 9090), - remote_addr: tcp4([127, 0, 0, 1], 32000) - }, - ListenerEvent::Upgrade { - upgrade: (peer_id.clone(), muxer.clone()), - local_addr: tcp4([127, 0, 0, 1], 9090), - remote_addr: tcp4([127, 0, 0, 1], 32000) - }, - ListenerEvent::Upgrade { - upgrade: (peer_id.clone(), muxer.clone()), - local_addr: tcp4([127, 0, 0, 1], 9090), - remote_addr: tcp4([127, 0, 0, 1], 32000) + match listeners.next().await.unwrap() { + ListenersEvent::Incoming { local_addr, upgrade, send_back_addr, .. } => { + assert_eq!(local_addr, address); + assert_eq!(send_back_addr, address); + }, + _ => panic!() } - ])); - - let mut ls = ListenersStream::new(t); - ls.listen_on(tcp4([127, 0, 0, 1], 1234)).expect("listen_on"); - ls.listen_on(tcp4([127, 0, 0, 1], 4321)).expect("listen_on"); - assert_eq!(ls.listeners.len(), 2); - - assert_matches!(ls.by_ref().wait().next(), Some(Ok(listeners_event)) => { - assert_matches!(listeners_event, ListenersEvent::NewAddress { .. }) }); - - assert_matches!(ls.by_ref().wait().next(), Some(Ok(listeners_event)) => { - assert_matches!(listeners_event, ListenersEvent::NewAddress { .. 
}) - }); - - assert_matches!(ls.by_ref().wait().next(), Some(Ok(listeners_event)) => { - assert_matches!(listeners_event, ListenersEvent::Incoming { upgrade, .. } => { - assert_matches!(upgrade.wait(), Ok(output) => { - assert_eq!(output, expected_output) - }); - }) - }); - - assert_matches!(ls.by_ref().wait().next(), Some(Ok(listeners_event)) => { - assert_matches!(listeners_event, ListenersEvent::Incoming { upgrade, .. } => { - assert_matches!(upgrade.wait(), Ok(output) => { - assert_eq!(output, expected_output) - }); - }) - }); - - set_listener_state(&mut ls, 1, ListenerState::Ok(Async::NotReady)); - - assert_matches!(ls.by_ref().wait().next(), Some(Ok(listeners_event)) => { - assert_matches!(listeners_event, ListenersEvent::Incoming { upgrade, .. } => { - assert_matches!(upgrade.wait(), Ok(output) => { - assert_eq!(output, expected_output) - }); - }) - }); - } - - #[test] - fn listener_stream_poll_with_closed_listener_emits_closed_event() { - let t = DummyTransport::new(); - let addr = tcp4([127, 0, 0, 1], 1234); - let mut ls = ListenersStream::new(t); - ls.listen_on(addr).expect("listen_on failed"); - set_listener_state(&mut ls, 0, ListenerState::Ok(Async::Ready(None))); - assert_matches!(ls.by_ref().wait().next(), Some(Ok(listeners_event)) => { - assert_matches!(listeners_event, ListenersEvent::Closed{..}) - }); - assert_eq!(ls.listeners.len(), 0); // it's gone - } - - #[test] - fn listener_stream_poll_with_erroring_listener_emits_error_event() { - let mut t = DummyTransport::new(); - let peer_id = PeerId::random(); - let muxer = DummyMuxer::new(); - let event = ListenerEvent::Upgrade { - upgrade: (peer_id, muxer), - local_addr: tcp4([127, 0, 0, 1], 1234), - remote_addr: tcp4([127, 0, 0, 1], 32000) - }; - t.set_initial_listener_state(ListenerState::Ok(Async::Ready(Some(event)))); - let addr = tcp4([127, 0, 0, 1], 1234); - let mut ls = ListenersStream::new(t); - ls.listen_on(addr).expect("listen_on failed"); - set_listener_state(&mut ls, 0, ListenerState::Error); // simulate an error on the socket - assert_matches!(ls.by_ref().wait().next(), Some(Ok(listeners_event)) => { - assert_matches!(listeners_event, ListenersEvent::Error{..}) - }); - assert_eq!(ls.listeners.len(), 0); // it's gone - } - - fn tcp4(ip: [u8; 4], port: u16) -> Multiaddr { - let protos = std::iter::once(multiaddr::Protocol::Ip4(ip.into())) - .chain(std::iter::once(multiaddr::Protocol::Tcp(port))); - Multiaddr::from_iter(protos) } } diff --git a/core/src/nodes/network.rs b/core/src/nodes/network.rs index abe9e631..2f6634a1 100644 --- a/core/src/nodes/network.rs +++ b/core/src/nodes/network.rs @@ -49,10 +49,10 @@ use std::{ fmt, hash::Hash, num::NonZeroUsize, + pin::Pin, + task::{Context, Poll}, }; -pub use crate::nodes::collection::StartTakeOver; - mod tests; /// Implementation of `Stream` that handles the nodes. @@ -81,7 +81,7 @@ where /// If the pair's second element is `AsyncSink::Ready`, the take over /// message has been sent and needs to be flushed using /// `PeerMut::complete_take_over`. - take_over_to_complete: Option<(TPeerId, AsyncSink>)> + take_over_to_complete: Option<(TPeerId, InterruptedReachAttempt)> } impl fmt::Debug for @@ -102,6 +102,13 @@ where } } +impl Unpin for + Network +where + TTrans: Transport +{ +} + impl ConnectionInfo for (TConnInfo, ConnectedPoint) where TConnInfo: ConnectionInfo @@ -173,7 +180,7 @@ where /// The listener that errored. listener_id: ListenerId, /// The listener error. - error: ::Error + error: ::Error }, /// One of the listeners is now listening on an additional address. 
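In the `listeners.rs` changes above, the `Stream` implementation now just forwards to the inherent `poll` method, and the struct gains an unconditional `Unpin` impl because its listeners are boxed. A small self-contained sketch of that delegation pattern, with a toy `Events` type in place of `ListenersStream` (none of these names come from the crate):

```rust
use futures::prelude::*;
use std::{pin::Pin, task::{Context, Poll}};

// Toy stand-in for `ListenersStream`: a counter is enough to show the pattern.
struct Events {
    remaining: u32,
}

impl Events {
    // Inherent poll, mirroring `ListenersStream::poll`: the stream never ends,
    // so the item is not wrapped in an `Option` here.
    fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<u32> {
        if self.remaining > 0 {
            self.remaining -= 1;
            Poll::Ready(self.remaining)
        } else {
            // A real implementation would have registered `cx.waker()` by now.
            Poll::Pending
        }
    }
}

// The `Stream` impl simply forwards and wraps the item in `Some`,
// much like the new `poll_next` in the diff above.
impl Stream for Events {
    type Item = u32;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Events::poll(self, cx).map(Some)
    }
}
```

With that in place the type can be consumed through `StreamExt` combinators such as `next().await`, which is how the rewritten `incoming_event` test drives the listeners inside `futures::executor::block_on`.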
@@ -573,7 +580,7 @@ impl<'a, TTrans, TInEvent, TOutEvent, TMuxer, THandler, THandlerErr, TConnInfo, where TTrans: Transport, TTrans::Error: Send + 'static, - TTrans::ListenerUpgrade: Send + 'static, + TTrans::ListenerUpgrade: Unpin + Send + 'static, THandler: IntoNodeHandler<(TConnInfo, ConnectedPoint)> + Send + 'static, THandler::Handler: NodeHandler, InEvent = TInEvent, OutEvent = TOutEvent, Error = THandlerErr> + Send + 'static, ::OutboundOpenInfo: Send + 'static, // TODO: shouldn't be necessary @@ -609,9 +616,9 @@ where let connected_point = connected_point.clone(); move |(peer_id, muxer)| { if *peer_id.peer_id() == local_peer_id { - Err(InternalReachErr::FoundLocalPeerId) + future::ready(Err(InternalReachErr::FoundLocalPeerId)) } else { - Ok(((peer_id, connected_point), muxer)) + future::ready(Ok(((peer_id, connected_point), muxer))) } } }); @@ -781,7 +788,7 @@ where where TTrans: Transport, TTrans::Error: Send + 'static, - TTrans::Dial: Send + 'static, + TTrans::Dial: Unpin + Send + 'static, TMuxer: Send + Sync + 'static, TMuxer::OutboundSubstream: Send, TInEvent: Send + 'static, @@ -797,9 +804,9 @@ where let connected_point = connected_point.clone(); move |(peer_id, muxer)| { if *peer_id.peer_id() == local_peer_id { - Err(InternalReachErr::FoundLocalPeerId) + future::ready(Err(InternalReachErr::FoundLocalPeerId)) } else { - Ok(((peer_id, connected_point), muxer)) + future::ready(Ok(((peer_id, connected_point), muxer))) } } }); @@ -840,19 +847,18 @@ where /// Start sending an event to all nodes. /// - /// Make sure to complete the broadcast with `complete_broadcast`. - #[must_use] - pub fn start_broadcast(&mut self, event: &TInEvent) -> AsyncSink<()> + /// Must be called only after a successful call to `poll_ready_broadcast`. + pub fn start_broadcast(&mut self, event: &TInEvent) where TInEvent: Clone { self.active_nodes.start_broadcast(event) } - /// Complete a broadcast initiated with `start_broadcast`. + /// Wait until we have enough room in senders to broadcast an event. #[must_use] - pub fn complete_broadcast(&mut self) -> Async<()> { - self.active_nodes.complete_broadcast() + pub fn poll_ready_broadcast(&mut self, cx: &mut Context) -> Poll<()> { + self.active_nodes.poll_ready_broadcast(cx) } /// Returns a list of all the peers we are currently connected to. @@ -934,7 +940,7 @@ where fn start_dial_out(&mut self, peer_id: TPeerId, handler: THandler, first: Multiaddr, rest: Vec) where TTrans: Transport, - TTrans::Dial: Send + 'static, + TTrans::Dial: Unpin + Send + 'static, TTrans::Error: Send + 'static, TMuxer: Send + Sync + 'static, TMuxer::OutboundSubstream: Send, @@ -950,9 +956,9 @@ where .map_err(|err| InternalReachErr::Transport(TransportError::Other(err))) .and_then(move |(actual_conn_info, muxer)| { if *actual_conn_info.peer_id() == expected_peer_id { - Ok(((actual_conn_info, connected_point), muxer)) + future::ready(Ok(((actual_conn_info, connected_point), muxer))) } else { - Err(InternalReachErr::PeerIdMismatch { obtained: actual_conn_info }) + future::ready(Err(InternalReachErr::PeerIdMismatch { obtained: actual_conn_info })) } }); self.active_nodes.add_reach_attempt(fut, handler) @@ -976,11 +982,12 @@ where } /// Provides an API similar to `Stream`, except that it cannot error. 
- pub fn poll(&mut self) -> Async> + pub fn poll<'a>(&'a mut self, cx: &mut Context) -> Poll> where TTrans: Transport, TTrans::Error: Send + 'static, - TTrans::Dial: Send + 'static, + TTrans::Dial: Unpin + Send + 'static, + TTrans::Listener: Unpin, TTrans::ListenerUpgrade: Send + 'static, TMuxer: Send + Sync + 'static, TMuxer::OutboundSubstream: Send, @@ -998,9 +1005,9 @@ where Some(x) if self.incoming_negotiated().count() >= (x as usize) => (), _ => { - match self.listeners.poll() { - Async::NotReady => (), - Async::Ready(ListenersEvent::Incoming { listener_id, upgrade, local_addr, send_back_addr }) => { + match ListenersStream::poll(Pin::new(&mut self.listeners), cx) { + Poll::Pending => (), + Poll::Ready(ListenersEvent::Incoming { listener_id, upgrade, local_addr, send_back_addr }) => { let event = IncomingConnectionEvent { listener_id, upgrade, @@ -1010,19 +1017,19 @@ where active_nodes: &mut self.active_nodes, other_reach_attempts: &mut self.reach_attempts.other_reach_attempts, }; - return Async::Ready(NetworkEvent::IncomingConnection(event)); + return Poll::Ready(NetworkEvent::IncomingConnection(event)); } - Async::Ready(ListenersEvent::NewAddress { listener_id, listen_addr }) => { - return Async::Ready(NetworkEvent::NewListenerAddress { listener_id, listen_addr }) + Poll::Ready(ListenersEvent::NewAddress { listener_id, listen_addr }) => { + return Poll::Ready(NetworkEvent::NewListenerAddress { listener_id, listen_addr }) } - Async::Ready(ListenersEvent::AddressExpired { listener_id, listen_addr }) => { - return Async::Ready(NetworkEvent::ExpiredListenerAddress { listener_id, listen_addr }) + Poll::Ready(ListenersEvent::AddressExpired { listener_id, listen_addr }) => { + return Poll::Ready(NetworkEvent::ExpiredListenerAddress { listener_id, listen_addr }) } - Async::Ready(ListenersEvent::Closed { listener_id, listener }) => { - return Async::Ready(NetworkEvent::ListenerClosed { listener_id, listener }) + Poll::Ready(ListenersEvent::Closed { listener_id, listener }) => { + return Poll::Ready(NetworkEvent::ListenerClosed { listener_id, listener }) } - Async::Ready(ListenersEvent::Error { listener_id, error }) => { - return Async::Ready(NetworkEvent::ListenerError { listener_id, error }) + Poll::Ready(ListenersEvent::Error { listener_id, error }) => { + return Poll::Ready(NetworkEvent::ListenerError { listener_id, error }) } } } @@ -1031,36 +1038,30 @@ where // Attempt to deliver any pending take over messages. if let Some((id, interrupted)) = self.take_over_to_complete.take() { if let Some(mut peer) = self.active_nodes.peer_mut(&id) { - if let AsyncSink::NotReady(i) = interrupted { - if let StartTakeOver::NotReady(i) = peer.start_take_over(i) { - self.take_over_to_complete = Some((id, AsyncSink::NotReady(i))) - } else if let Ok(Async::NotReady) = peer.complete_take_over() { - self.take_over_to_complete = Some((id, AsyncSink::Ready)) - } - } else if let Ok(Async::NotReady) = peer.complete_take_over() { - self.take_over_to_complete = Some((id, AsyncSink::Ready)) + if let Poll::Ready(()) = peer.poll_ready_take_over(cx) { + peer.start_take_over(interrupted); + } else { + self.take_over_to_complete = Some((id, interrupted)); + return Poll::Pending; } } } - if self.take_over_to_complete.is_some() { - return Async::NotReady - } // Poll the existing nodes. 
let (action, out_event); - match self.active_nodes.poll() { - Async::NotReady => return Async::NotReady, - Async::Ready(CollectionEvent::NodeReached(reach_event)) => { + match self.active_nodes.poll(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(CollectionEvent::NodeReached(reach_event)) => { let (a, e) = handle_node_reached(&mut self.reach_attempts, reach_event); action = a; out_event = e; } - Async::Ready(CollectionEvent::ReachError { id, error, handler }) => { + Poll::Ready(CollectionEvent::ReachError { id, error, handler }) => { let (a, e) = handle_reach_error(&mut self.reach_attempts, id, error, handler); action = a; out_event = e; } - Async::Ready(CollectionEvent::NodeClosed { + Poll::Ready(CollectionEvent::NodeClosed { conn_info, error, .. @@ -1078,7 +1079,7 @@ where error, }; } - Async::Ready(CollectionEvent::NodeEvent { peer, event }) => { + Poll::Ready(CollectionEvent::NodeEvent { peer, event }) => { action = Default::default(); out_event = NetworkEvent::NodeEvent { conn_info: peer.info().0.clone(), event }; } @@ -1099,17 +1100,15 @@ where out_reach_attempts should always be in sync with the actual \ attempts; QED"); let mut peer = self.active_nodes.peer_mut(&peer_id).unwrap(); - if let StartTakeOver::NotReady(i) = peer.start_take_over(interrupted) { - self.take_over_to_complete = Some((peer_id, AsyncSink::NotReady(i))); - return Async::NotReady - } - if let Ok(Async::NotReady) = peer.complete_take_over() { - self.take_over_to_complete = Some((peer_id, AsyncSink::Ready)); - return Async::NotReady + if let Poll::Ready(()) = peer.poll_ready_take_over(cx) { + peer.start_take_over(interrupted); + } else { + self.take_over_to_complete = Some((peer_id, interrupted)); + return Poll::Pending } } - Async::Ready(out_event) + Poll::Ready(out_event) } } @@ -1467,7 +1466,7 @@ impl<'a, TTrans, TMuxer, TInEvent, TOutEvent, THandler, THandlerErr, TConnInfo, where TTrans: Transport + Clone, TTrans::Error: Send + 'static, - TTrans::Dial: Send + 'static, + TTrans::Dial: Unpin + Send + 'static, TMuxer: StreamMuxer + Send + Sync + 'static, TMuxer::OutboundSubstream: Send, TMuxer::Substream: Send, @@ -1644,18 +1643,33 @@ where closed messages; QED") } - /// Start sending an event to the node. - pub fn start_send_event(&mut self, event: TInEvent) -> StartSend { + /// Sends an event to the handler of the node. + pub fn send_event<'s: 'a>(&'s mut self, event: TInEvent) -> impl Future + 's + 'a { + let mut event = Some(event); + futures::future::poll_fn(move |cx| { + match self.poll_ready_event(cx) { + Poll::Ready(()) => { + self.start_send_event(event.take().expect("Future called after finished")); + Poll::Ready(()) + }, + Poll::Pending => Poll::Pending, + } + }) + } + + /// Begin sending an event to the node. Must be called only after a successful call to + /// `poll_ready_event`. + pub fn start_send_event(&mut self, event: TInEvent) { self.active_nodes.peer_mut(&self.peer_id) .expect("A PeerConnected is always created with a PeerId in active_nodes; QED") .start_send_event(event) } - /// Complete sending an event message, initiated by `start_send_event`. - pub fn complete_send_event(&mut self) -> Poll<(), ()> { + /// Make sure we are ready to accept an event to be sent with `start_send_event`. 
+ pub fn poll_ready_event(&mut self, cx: &mut Context) -> Poll<()> { self.active_nodes.peer_mut(&self.peer_id) .expect("A PeerConnected is always created with a PeerId in active_nodes; QED") - .complete_send_event() + .poll_ready_event(cx) } } @@ -1749,7 +1763,7 @@ impl<'a, TTrans, TInEvent, TOutEvent, TMuxer, THandler, THandlerErr, TConnInfo, where TTrans: Transport + Clone, TTrans::Error: Send + 'static, - TTrans::Dial: Send + 'static, + TTrans::Dial: Unpin + Send + 'static, TMuxer: StreamMuxer + Send + Sync + 'static, TMuxer::OutboundSubstream: Send, TMuxer::Substream: Send, diff --git a/core/src/nodes/network/tests.rs b/core/src/nodes/network/tests.rs index c64666aa..c4f307bb 100644 --- a/core/src/nodes/network/tests.rs +++ b/core/src/nodes/network/tests.rs @@ -21,363 +21,6 @@ #![cfg(test)] use super::*; -use crate::tests::dummy_transport::DummyTransport; -use crate::tests::dummy_handler::{Handler, HandlerState, InEvent, OutEvent}; -use crate::tests::dummy_transport::ListenerState; -use crate::tests::dummy_muxer::{DummyMuxer, DummyConnectionState}; -use crate::nodes::NodeHandlerEvent; -use crate::transport::ListenerEvent; -use assert_matches::assert_matches; -use parking_lot::Mutex; -use std::sync::Arc; -use tokio::runtime::{Builder, Runtime}; - -#[test] -fn query_transport() { - let transport = DummyTransport::new(); - let transport2 = transport.clone(); - let network = Network::<_, _, _, Handler, _>::new(transport, PeerId::random()); - assert_eq!(network.transport(), &transport2); -} - -#[test] -fn local_node_peer() { - let peer_id = PeerId::random(); - let mut network = Network::<_, _, _, Handler, _>::new(DummyTransport::new(), peer_id.clone()); - assert_matches!(network.peer(peer_id), Peer::LocalNode); -} - -#[test] -fn successful_dial_reaches_a_node() { - let mut network = Network::<_, _, _, Handler, _>::new(DummyTransport::new(), PeerId::random()); - let addr = "/ip4/127.0.0.1/tcp/1234".parse::().expect("bad multiaddr"); - let dial_res = network.dial(addr, Handler::default()); - assert!(dial_res.is_ok()); - - // Poll the network until we get a `NodeReached` then assert on the peer: - // it's there and it's connected. - let network = Arc::new(Mutex::new(network)); - - let mut rt = Runtime::new().unwrap(); - let mut peer_id : Option = None; - // Drive forward until we're Connected - while peer_id.is_none() { - let network_fut = network.clone(); - peer_id = rt.block_on(future::poll_fn(move || -> Poll, ()> { - let mut network = network_fut.lock(); - let poll_res = network.poll(); - match poll_res { - Async::Ready(NetworkEvent::Connected { conn_info, .. 
}) => Ok(Async::Ready(Some(conn_info))), - _ => Ok(Async::Ready(None)) - } - })).expect("tokio works"); - } - - let mut network = network.lock(); - let peer = network.peer(peer_id.unwrap()); - assert_matches!(peer, Peer::Connected(PeerConnected{..})); -} - -#[test] -fn num_incoming_negotiated() { - let mut transport = DummyTransport::new(); - let peer_id = PeerId::random(); - let muxer = DummyMuxer::new(); - - let events = vec![ - ListenerEvent::NewAddress("/ip4/127.0.0.1/tcp/1234".parse().unwrap()), - ListenerEvent::Upgrade { - upgrade: (peer_id.clone(), muxer.clone()), - local_addr: "/ip4/127.0.0.1/tcp/1234".parse().unwrap(), - remote_addr: "/ip4/127.0.0.1/tcp/32111".parse().unwrap() - } - ]; - transport.set_initial_listener_state(ListenerState::Events(events)); - - let mut network = Network::<_, _, _, Handler, _>::new(transport, PeerId::random()); - network.listen_on("/memory/0".parse().unwrap()).unwrap(); - - // no incoming yet - assert_eq!(network.incoming_negotiated().count(), 0); - - let mut rt = Runtime::new().unwrap(); - let network = Arc::new(Mutex::new(network)); - let network_fut = network.clone(); - let fut = future::poll_fn(move || -> Poll<_, ()> { - let mut network_fut = network_fut.lock(); - assert_matches!(network_fut.poll(), Async::Ready(NetworkEvent::NewListenerAddress {..})); - assert_matches!(network_fut.poll(), Async::Ready(NetworkEvent::IncomingConnection(incoming)) => { - incoming.accept(Handler::default()); - }); - Ok(Async::Ready(())) - }); - rt.block_on(fut).expect("tokio works"); - let network = network.lock(); - // Now there's an incoming connection - assert_eq!(network.incoming_negotiated().count(), 1); -} - -#[test] -fn broadcasted_events_reach_active_nodes() { - let mut network = Network::<_, _, _, Handler, _>::new(DummyTransport::new(), PeerId::random()); - let mut muxer = DummyMuxer::new(); - muxer.set_inbound_connection_state(DummyConnectionState::Pending); - muxer.set_outbound_connection_state(DummyConnectionState::Opened); - let addr = "/ip4/127.0.0.1/tcp/1234".parse::().expect("bad multiaddr"); - let mut handler = Handler::default(); - handler.next_states = vec![HandlerState::Ready(NodeHandlerEvent::Custom(OutEvent::Custom("from handler 1") )),]; - let dial_result = network.dial(addr, handler); - assert!(dial_result.is_ok()); - - let network = Arc::new(Mutex::new(network)); - let mut rt = Runtime::new().unwrap(); - let network2 = network.clone(); - rt.block_on(future::poll_fn(move || { - if network2.lock().start_broadcast(&InEvent::NextState).is_not_ready() { - Ok::<_, ()>(Async::NotReady) - } else { - Ok(Async::Ready(())) - } - })).unwrap(); - let mut peer_id : Option = None; - while peer_id.is_none() { - let network_fut = network.clone(); - peer_id = rt.block_on(future::poll_fn(move || -> Poll, ()> { - let mut network = network_fut.lock(); - if network.complete_broadcast().is_not_ready() { - return Ok(Async::NotReady) - } - let poll_res = network.poll(); - match poll_res { - Async::Ready(NetworkEvent::Connected { conn_info, .. 
}) => Ok(Async::Ready(Some(conn_info))), - _ => Ok(Async::Ready(None)) - } - })).expect("tokio works"); - } - - let mut keep_polling = true; - while keep_polling { - let network_fut = network.clone(); - keep_polling = rt.block_on(future::poll_fn(move || -> Poll<_, ()> { - let mut network = network_fut.lock(); - match network.poll() { - Async::Ready(event) => { - assert_matches!(event, NetworkEvent::NodeEvent { conn_info: _, event: inner_event } => { - // The event we sent reached the node and triggered sending the out event we told it to return - assert_matches!(inner_event, OutEvent::Custom("from handler 1")); - }); - Ok(Async::Ready(false)) - }, - _ => Ok(Async::Ready(true)) - } - })).expect("tokio works"); - } -} - -#[test] -fn querying_for_pending_peer() { - let mut network = Network::<_, _, _, Handler, _>::new(DummyTransport::new(), PeerId::random()); - let peer_id = PeerId::random(); - let peer = network.peer(peer_id.clone()); - assert_matches!(peer, Peer::NotConnected(PeerNotConnected{ .. })); - let addr = "/memory/0".parse().expect("bad multiaddr"); - let pending_peer = peer.into_not_connected().unwrap().connect(addr, Handler::default()); - assert_matches!(pending_peer, PeerPendingConnect { .. }); -} - -#[test] -fn querying_for_unknown_peer() { - let mut network = Network::<_, _, _, Handler, _>::new(DummyTransport::new(), PeerId::random()); - let peer_id = PeerId::random(); - let peer = network.peer(peer_id.clone()); - assert_matches!(peer, Peer::NotConnected( PeerNotConnected { nodes: _, peer_id: node_peer_id }) => { - assert_eq!(node_peer_id, peer_id); - }); -} - -#[test] -fn querying_for_connected_peer() { - let mut network = Network::<_, _, _, Handler, _>::new(DummyTransport::new(), PeerId::random()); - - // Dial a node - let addr = "/ip4/127.0.0.1/tcp/1234".parse().expect("bad multiaddr"); - network.dial(addr, Handler::default()).expect("dialing works"); - - let network = Arc::new(Mutex::new(network)); - let mut rt = Runtime::new().unwrap(); - // Drive it forward until we connect; extract the new PeerId. - let mut peer_id : Option = None; - while peer_id.is_none() { - let network_fut = network.clone(); - peer_id = rt.block_on(future::poll_fn(move || -> Poll, ()> { - let mut network = network_fut.lock(); - let poll_res = network.poll(); - match poll_res { - Async::Ready(NetworkEvent::Connected { conn_info, .. }) => Ok(Async::Ready(Some(conn_info))), - _ => Ok(Async::Ready(None)) - } - })).expect("tokio works"); - } - - // We're connected. - let mut network = network.lock(); - let peer = network.peer(peer_id.unwrap()); - assert_matches!(peer, Peer::Connected( PeerConnected { .. } )); -} - -#[test] -fn poll_with_closed_listener() { - let mut transport = DummyTransport::new(); - // Set up listener to be closed - transport.set_initial_listener_state(ListenerState::Ok(Async::Ready(None))); - - let mut network = Network::<_, _, _, Handler, _>::new(transport, PeerId::random()); - network.listen_on("/memory/0".parse().unwrap()).unwrap(); - - let mut rt = Runtime::new().unwrap(); - let network = Arc::new(Mutex::new(network)); - - let network_fut = network.clone(); - let fut = future::poll_fn(move || -> Poll<_, ()> { - let mut network = network_fut.lock(); - assert_matches!(network.poll(), Async::Ready(NetworkEvent::ListenerClosed { .. 
} )); - Ok(Async::Ready(())) - }); - rt.block_on(fut).expect("tokio works"); -} - -#[test] -fn unknown_peer_that_is_unreachable_yields_unknown_peer_dial_error() { - let mut transport = DummyTransport::new(); - transport.make_dial_fail(); - let mut network = Network::<_, _, _, Handler, _>::new(transport, PeerId::random()); - let addr = "/memory/0".parse::().expect("bad multiaddr"); - let handler = Handler::default(); - let dial_result = network.dial(addr, handler); - assert!(dial_result.is_ok()); - - let network = Arc::new(Mutex::new(network)); - let mut rt = Runtime::new().unwrap(); - // Drive it forward until we hear back from the node. - let mut keep_polling = true; - while keep_polling { - let network_fut = network.clone(); - keep_polling = rt.block_on(future::poll_fn(move || -> Poll<_, ()> { - let mut network = network_fut.lock(); - match network.poll() { - Async::NotReady => Ok(Async::Ready(true)), - Async::Ready(event) => { - assert_matches!(event, NetworkEvent::UnknownPeerDialError { .. } ); - Ok(Async::Ready(false)) - }, - } - })).expect("tokio works"); - } -} - -#[test] -fn known_peer_that_is_unreachable_yields_dial_error() { - let mut transport = DummyTransport::new(); - let peer_id = PeerId::random(); - transport.set_next_peer_id(&peer_id); - transport.make_dial_fail(); - let network = Arc::new(Mutex::new(Network::<_, _, _, Handler, _>::new(transport, PeerId::random()))); - - { - let network1 = network.clone(); - let mut network1 = network1.lock(); - let peer = network1.peer(peer_id.clone()); - assert_matches!(peer, Peer::NotConnected(PeerNotConnected{ .. })); - let addr = "/memory/0".parse::().expect("bad multiaddr"); - let pending_peer = peer.into_not_connected().unwrap().connect(addr, Handler::default()); - assert_matches!(pending_peer, PeerPendingConnect { .. }); - } - let mut rt = Runtime::new().unwrap(); - // Drive it forward until we hear back from the node. - let mut keep_polling = true; - while keep_polling { - let network_fut = network.clone(); - let peer_id = peer_id.clone(); - keep_polling = rt.block_on(future::poll_fn(move || -> Poll<_, ()> { - let mut network = network_fut.lock(); - match network.poll() { - Async::NotReady => Ok(Async::Ready(true)), - Async::Ready(event) => { - let failed_peer_id = assert_matches!( - event, - NetworkEvent::DialError { new_state: _, peer_id: failed_peer_id, .. } => failed_peer_id - ); - assert_eq!(peer_id, failed_peer_id); - Ok(Async::Ready(false)) - }, - } - })).expect("tokio works"); - } -} - -#[test] -fn yields_node_error_when_there_is_an_error_after_successful_connect() { - let mut transport = DummyTransport::new(); - let peer_id = PeerId::random(); - transport.set_next_peer_id(&peer_id); - let network = Arc::new(Mutex::new(Network::<_, _, _, Handler, _>::new(transport, PeerId::random()))); - - { - // Set up an outgoing connection with a PeerId we know - let network1 = network.clone(); - let mut network1 = network1.lock(); - let peer = network1.peer(peer_id.clone()); - let addr = "/unix/reachable".parse().expect("bad multiaddr"); - let mut handler = Handler::default(); - // Force an error - handler.next_states = vec![ HandlerState::Err ]; - peer.into_not_connected().unwrap().connect(addr, handler); - } - - // Ensure we run on a single thread - let mut rt = Builder::new().core_threads(1).build().unwrap(); - - // Drive it forward until we connect to the node. 
- let mut keep_polling = true; - while keep_polling { - let network_fut = network.clone(); - let network2 = network.clone(); - rt.block_on(future::poll_fn(move || { - if network2.lock().start_broadcast(&InEvent::NextState).is_not_ready() { - Ok::<_, ()>(Async::NotReady) - } else { - Ok(Async::Ready(())) - } - })).unwrap(); - keep_polling = rt.block_on(future::poll_fn(move || -> Poll<_, ()> { - let mut network = network_fut.lock(); - // Push the Handler into an error state on the next poll - if network.complete_broadcast().is_not_ready() { - return Ok(Async::NotReady) - } - match network.poll() { - Async::NotReady => Ok(Async::Ready(true)), - Async::Ready(event) => { - assert_matches!(event, NetworkEvent::Connected { .. }); - // We're connected, we can move on - Ok(Async::Ready(false)) - }, - } - })).expect("tokio works"); - } - - // Poll again. It is going to be a NodeClosed because of how the - // handler's next state was set up. - let network_fut = network.clone(); - let expected_peer_id = peer_id.clone(); - rt.block_on(future::poll_fn(move || -> Poll<_, ()> { - let mut network = network_fut.lock(); - assert_matches!(network.poll(), Async::Ready(NetworkEvent::NodeClosed { conn_info, .. }) => { - assert_eq!(conn_info, expected_peer_id); - }); - Ok(Async::Ready(())) - })).expect("tokio works"); -} #[test] fn local_prio_equivalence_relation() { @@ -387,59 +30,3 @@ fn local_prio_equivalence_relation() { assert_ne!(has_dial_prio(&a, &b), has_dial_prio(&b, &a)); } } - -#[test] -fn limit_incoming_connections() { - let mut transport = DummyTransport::new(); - let peer_id = PeerId::random(); - let muxer = DummyMuxer::new(); - let limit = 1; - - let mut events = vec![ListenerEvent::NewAddress("/ip4/127.0.0.1/tcp/1234".parse().unwrap())]; - events.extend(std::iter::repeat( - ListenerEvent::Upgrade { - upgrade: (peer_id.clone(), muxer.clone()), - local_addr: "/ip4/127.0.0.1/tcp/1234".parse().unwrap(), - remote_addr: "/ip4/127.0.0.1/tcp/32111".parse().unwrap() - } - ).take(10)); - transport.set_initial_listener_state(ListenerState::Events(events)); - - let mut network = Network::<_, _, _, Handler, _>::new_with_incoming_limit(transport, PeerId::random(), Some(limit)); - assert_eq!(network.incoming_limit(), Some(limit)); - network.listen_on("/memory/0".parse().unwrap()).unwrap(); - assert_eq!(network.incoming_negotiated().count(), 0); - - let network = Arc::new(Mutex::new(network)); - let mut rt = Runtime::new().unwrap(); - for i in 1..10 { - let network_fut = network.clone(); - let fut = future::poll_fn(move || -> Poll<_, ()> { - let mut network_fut = network_fut.lock(); - if i <= limit { - assert_matches!(network_fut.poll(), Async::Ready(NetworkEvent::NewListenerAddress {..})); - assert_matches!(network_fut.poll(), - Async::Ready(NetworkEvent::IncomingConnection(incoming)) => { - incoming.accept(Handler::default()); - }); - } else { - match network_fut.poll() { - Async::NotReady => (), - Async::Ready(x) => { - match x { - NetworkEvent::NewListenerAddress {..} => {} - NetworkEvent::ExpiredListenerAddress {..} => {} - NetworkEvent::IncomingConnection(_) => {} - NetworkEvent::Connected {..} => {} - e => panic!("Not expected event: {:?}", e) - } - }, - } - } - Ok(Async::Ready(())) - }); - rt.block_on(fut).expect("tokio works"); - let network = network.lock(); - assert!(network.incoming_negotiated().count() <= (limit as usize)); - } -} diff --git a/core/src/nodes/node.rs b/core/src/nodes/node.rs index a1d0eac4..37da9954 100644 --- a/core/src/nodes/node.rs +++ b/core/src/nodes/node.rs @@ -21,9 +21,7 @@ use 
futures::prelude::*; use crate::muxing; use smallvec::SmallVec; -use std::fmt; -use std::io::Error as IoError; -use std::sync::Arc; +use std::{fmt, io::Error as IoError, pin::Pin, sync::Arc, task::Context, task::Poll}; // Implementation notes // ================= @@ -143,43 +141,44 @@ where } /// Provides an API similar to `Future`. - pub fn poll(&mut self) -> Poll, IoError> { + pub fn poll(&mut self, cx: &mut Context) -> Poll, IoError>> { // Polling inbound substream. - match self.muxer.poll_inbound().map_err(|e| e.into())? { - Async::Ready(substream) => { + match self.muxer.poll_inbound(cx) { + Poll::Ready(Ok(substream)) => { let substream = muxing::substream_from_ref(self.muxer.clone(), substream); - return Ok(Async::Ready(NodeEvent::InboundSubstream { + return Poll::Ready(Ok(NodeEvent::InboundSubstream { substream, })); } - Async::NotReady => {} + Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())), + Poll::Pending => {} } // Polling outbound substreams. // We remove each element from `outbound_substreams` one by one and add them back. for n in (0..self.outbound_substreams.len()).rev() { let (user_data, mut outbound) = self.outbound_substreams.swap_remove(n); - match self.muxer.poll_outbound(&mut outbound) { - Ok(Async::Ready(substream)) => { + match self.muxer.poll_outbound(cx, &mut outbound) { + Poll::Ready(Ok(substream)) => { let substream = muxing::substream_from_ref(self.muxer.clone(), substream); self.muxer.destroy_outbound(outbound); - return Ok(Async::Ready(NodeEvent::OutboundSubstream { + return Poll::Ready(Ok(NodeEvent::OutboundSubstream { user_data, substream, })); } - Ok(Async::NotReady) => { + Poll::Pending => { self.outbound_substreams.push((user_data, outbound)); } - Err(err) => { + Poll::Ready(Err(err)) => { self.muxer.destroy_outbound(outbound); - return Err(err.into()); + return Poll::Ready(Err(err.into())); } } } // Nothing happened. Register our task to be notified and return. 
- Ok(Async::NotReady) + Poll::Pending } } @@ -212,11 +211,14 @@ impl Future for Close where TMuxer: muxing::StreamMuxer, { - type Item = (); - type Error = IoError; + type Output = Result<(), IoError>; - fn poll(&mut self) -> Poll { - self.muxer.close().map_err(|e| e.into()) + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + match self.muxer.close(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(())) => Poll::Ready(Ok(())), + Poll::Ready(Err(err)) => Poll::Ready(Err(err.into())), + } } } @@ -252,70 +254,3 @@ where } } } - -#[cfg(test)] -mod node_stream { - use super::{NodeEvent, NodeStream}; - use crate::tests::dummy_muxer::{DummyMuxer, DummyConnectionState}; - use assert_matches::assert_matches; - use futures::prelude::*; - use tokio_mock_task::MockTask; - - fn build_node_stream() -> NodeStream> { - let muxer = DummyMuxer::new(); - NodeStream::<_, Vec>::new(muxer) - } - - #[test] - fn closing_a_node_stream_destroys_substreams_and_returns_submitted_user_data() { - let mut ns = build_node_stream(); - ns.open_substream(vec![2]); - ns.open_substream(vec![3]); - ns.open_substream(vec![5]); - let user_data_submitted = ns.close(); - assert_eq!(user_data_submitted.1, vec![ - vec![2], vec![3], vec![5] - ]); - } - - #[test] - fn poll_returns_not_ready_when_there_is_nothing_to_do() { - let mut task = MockTask::new(); - task.enter(|| { - // ensure the address never resolves - let mut muxer = DummyMuxer::new(); - // ensure muxer.poll_inbound() returns Async::NotReady - muxer.set_inbound_connection_state(DummyConnectionState::Pending); - // ensure muxer.poll_outbound() returns Async::NotReady - muxer.set_outbound_connection_state(DummyConnectionState::Pending); - let mut ns = NodeStream::<_, Vec>::new(muxer); - - assert_matches!(ns.poll(), Ok(Async::NotReady)); - }); - } - - #[test] - fn poll_keeps_outbound_substreams_when_the_outgoing_connection_is_not_ready() { - let mut muxer = DummyMuxer::new(); - // ensure muxer.poll_inbound() returns Async::NotReady - muxer.set_inbound_connection_state(DummyConnectionState::Pending); - // ensure muxer.poll_outbound() returns Async::NotReady - muxer.set_outbound_connection_state(DummyConnectionState::Pending); - let mut ns = NodeStream::<_, Vec>::new(muxer); - ns.open_substream(vec![1]); - ns.poll().unwrap(); // poll past inbound - ns.poll().unwrap(); // poll outbound - assert!(format!("{:?}", ns).contains("outbound_substreams: 1")); - } - - #[test] - fn poll_returns_incoming_substream() { - let mut muxer = DummyMuxer::new(); - // ensure muxer.poll_inbound() returns Async::Ready(subs) - muxer.set_inbound_connection_state(DummyConnectionState::Opened); - let mut ns = NodeStream::<_, Vec>::new(muxer); - assert_matches!(ns.poll(), Ok(Async::Ready(node_event)) => { - assert_matches!(node_event, NodeEvent::InboundSubstream{ substream: _ }); - }); - } -} diff --git a/core/src/nodes/tasks/manager.rs b/core/src/nodes/tasks/manager.rs index 33643aaa..aff72bd9 100644 --- a/core/src/nodes/tasks/manager.rs +++ b/core/src/nodes/tasks/manager.rs @@ -27,9 +27,8 @@ use crate::{ } }; use fnv::FnvHashMap; -use futures::{prelude::*, future::Executor, sync::mpsc}; -use smallvec::SmallVec; -use std::{collections::hash_map::{Entry, OccupiedEntry}, error, fmt}; +use futures::{prelude::*, channel::mpsc, executor::ThreadPool, stream::FuturesUnordered, task::SpawnExt as _}; +use std::{collections::hash_map::{Entry, OccupiedEntry}, error, fmt, pin::Pin, task::Context, task::Poll}; use super::{TaskId, task::{Task, FromTaskMessage, ToTaskMessage}, Error}; // Implementor 
notes @@ -64,12 +63,13 @@ pub struct Manager { /// Identifier for the next task to spawn. next_task_id: TaskId, - /// List of node tasks to spawn. - to_spawn: SmallVec<[Box + Send>; 8]>, + /// Threads pool where we spawn the nodes' tasks. If `None`, then we push tasks to the + /// `local_spawns` list instead. + threads_pool: Option, - /// If no tokio executor is available, we move tasks to this list, and futures are polled on - /// the current thread instead. - local_spawns: Vec + Send>>, + /// If no executor is available, we move tasks to this list, and futures are polled on the + /// current thread instead. + local_spawns: FuturesUnordered + Send>>>, /// Sender to emit events to the outside. Meant to be cloned and sent to tasks. events_tx: mpsc::Sender<(FromTaskMessage, TaskId)>, @@ -91,16 +91,13 @@ where /// Information about a running task. /// -/// Contains the sender to deliver event messages to the task, -/// the associated user data and a pending message if any, -/// meant to be delivered to the task via the sender. +/// Contains the sender to deliver event messages to the task, and +/// the associated user data. struct TaskInfo { /// channel endpoint to send messages to the task sender: mpsc::Sender>, /// task associated data user_data: T, - /// any pending event to deliver to the task - pending: Option>> } /// Event produced by the [`Manager`]. @@ -140,11 +137,15 @@ impl Manager { /// Creates a new task manager. pub fn new() -> Self { let (tx, rx) = mpsc::channel(1); + let threads_pool = ThreadPool::builder() + .name_prefix("libp2p-nodes-") + .create().ok(); + Self { tasks: FnvHashMap::default(), next_task_id: TaskId(0), - to_spawn: SmallVec::new(), - local_spawns: Vec::new(), + threads_pool, + local_spawns: FuturesUnordered::new(), events_tx: tx, events_rx: rx } @@ -156,7 +157,7 @@ impl Manager { /// processing the node's events. 
pub fn add_reach_attempt(&mut self, future: F, user_data: T, handler: H) -> TaskId where - F: Future + Send + 'static, + F: Future> + Unpin + Send + 'static, H: IntoNodeHandler + Send + 'static, H::Handler: NodeHandler, InEvent = I, OutEvent = O, Error = HE> + Send + 'static, E: error::Error + Send + 'static, @@ -172,10 +173,14 @@ impl Manager { self.next_task_id.0 += 1; let (tx, rx) = mpsc::channel(4); - self.tasks.insert(task_id, TaskInfo { sender: tx, user_data, pending: None }); + self.tasks.insert(task_id, TaskInfo { sender: tx, user_data }); - let task = Box::new(Task::new(task_id, self.events_tx.clone(), rx, future, handler)); - self.to_spawn.push(task); + let task = Box::pin(Task::new(task_id, self.events_tx.clone(), rx, future, handler)); + if let Some(threads_pool) = &mut self.threads_pool { + threads_pool.spawn(task).expect("spawning a task on a threads pool never fails; qed"); + } else { + self.local_spawns.push(task); + } task_id } @@ -202,71 +207,56 @@ impl Manager { self.next_task_id.0 += 1; let (tx, rx) = mpsc::channel(4); - self.tasks.insert(task_id, TaskInfo { sender: tx, user_data, pending: None }); + self.tasks.insert(task_id, TaskInfo { sender: tx, user_data }); - let task: Task, _, _, _, _, _, _> = + // TODO: we use `Pin>` instead of just `Pending` because `Pending` doesn't + // implement `Unpin` even though it should ; this is just a dummy template parameter and + // the `Box` is never actually created, so this has no repercusion whatsoever + // see https://github.com/rust-lang-nursery/futures-rs/pull/1746 + let task: Task>>, _, _, _, _, _, _> = Task::node(task_id, self.events_tx.clone(), rx, HandledNode::new(muxer, handler)); - self.to_spawn.push(Box::new(task)); + if let Some(threads_pool) = &mut self.threads_pool { + threads_pool.spawn(Box::pin(task)).expect("spawning a task on a threads pool never fails; qed"); + } else { + self.local_spawns.push(Box::pin(task)); + } + task_id } /// Start sending an event to all the tasks, including the pending ones. /// + /// Must be called only after a successful call to `poll_ready_broadcast`. + /// /// After starting a broadcast make sure to finish it with `complete_broadcast`, /// otherwise starting another broadcast or sending an event directly to a /// task would overwrite the pending broadcast. #[must_use] - pub fn start_broadcast(&mut self, event: &I) -> AsyncSink<()> + pub fn start_broadcast(&mut self, event: &I) where I: Clone { - if self.complete_broadcast().is_not_ready() { - return AsyncSink::NotReady(()) - } - for task in self.tasks.values_mut() { let msg = ToTaskMessage::HandlerEvent(event.clone()); - task.pending = Some(AsyncSink::NotReady(msg)) + match task.sender.start_send(msg) { + Ok(()) => {}, + Err(ref err) if err.is_full() => {}, // TODO: somehow report to user? + Err(_) => {}, + } } - - AsyncSink::Ready } - /// Complete a started broadcast. + /// Wait until we have enough room in senders to broadcast an event. 
#[must_use] - pub fn complete_broadcast(&mut self) -> Async<()> { - let mut ready = true; - + pub fn poll_ready_broadcast(&mut self, cx: &mut Context) -> Poll<()> { for task in self.tasks.values_mut() { - match task.pending.take() { - Some(AsyncSink::NotReady(msg)) => - match task.sender.start_send(msg) { - Ok(AsyncSink::NotReady(msg)) => { - task.pending = Some(AsyncSink::NotReady(msg)); - ready = false - } - Ok(AsyncSink::Ready) => - if let Ok(Async::NotReady) = task.sender.poll_complete() { - task.pending = Some(AsyncSink::Ready); - ready = false - } - Err(_) => {} - } - Some(AsyncSink::Ready) => - if let Ok(Async::NotReady) = task.sender.poll_complete() { - task.pending = Some(AsyncSink::Ready); - ready = false - } - None => {} + if let Poll::Pending = task.sender.poll_ready(cx) { + return Poll::Pending; } } - if ready { - Async::Ready(()) - } else { - Async::NotReady - } + Poll::Ready(()) } /// Grants access to an object that allows controlling a task of the collection. @@ -285,32 +275,13 @@ impl Manager { } /// Provides an API similar to `Stream`, except that it cannot produce an error. - pub fn poll(&mut self) -> Async> { - for to_spawn in self.to_spawn.drain() { - // We try to use the default executor, but fall back to polling the task manually if - // no executor is available. This makes it possible to use the core in environments - // outside of tokio. - let executor = tokio_executor::DefaultExecutor::current(); - if let Err(err) = executor.execute(to_spawn) { - self.local_spawns.push(err.into_future()) - } - } - - for n in (0 .. self.local_spawns.len()).rev() { - let mut task = self.local_spawns.swap_remove(n); - match task.poll() { - Ok(Async::Ready(())) => {} - Ok(Async::NotReady) => self.local_spawns.push(task), - // It would normally be desirable to either report or log when a background task - // errors. However the default tokio executor doesn't do anything in case of error, - // and therefore we mimic this behaviour by also not doing anything. - Err(()) => {} - } - } + pub fn poll(&mut self, cx: &mut Context) -> Poll> { + // Advance the content of `local_spawns`. + while let Poll::Ready(Some(_)) = Stream::poll_next(Pin::new(&mut self.local_spawns), cx) {} let (message, task_id) = loop { - match self.events_rx.poll() { - Ok(Async::Ready(Some((message, task_id)))) => { + match Stream::poll_next(Pin::new(&mut self.events_rx), cx) { + Poll::Ready(Some((message, task_id))) => { // If the task id is no longer in `self.tasks`, that means that the user called // `close()` on this task earlier. Therefore no new event should be generated // for this task. @@ -318,13 +289,12 @@ impl Manager { break (message, task_id) } } - Ok(Async::NotReady) => return Async::NotReady, - Ok(Async::Ready(None)) => unreachable!("sender and receiver have same lifetime"), - Err(()) => unreachable!("An `mpsc::Receiver` does not error.") + Poll::Pending => return Poll::Pending, + Poll::Ready(None) => unreachable!("sender and receiver have same lifetime"), } }; - Async::Ready(match message { + Poll::Ready(match message { FromTaskMessage::NodeEvent(event) => Event::NodeEvent { task: match self.tasks.entry(task_id) { @@ -360,24 +330,16 @@ pub struct TaskEntry<'a, E, T> { } impl<'a, E, T> TaskEntry<'a, E, T> { - /// Begin sending an event to the given node. - /// - /// Make sure to finish the send operation with `complete_send_event`. - pub fn start_send_event(&mut self, event: E) -> StartSend { + /// Begin sending an event to the given node. 
Must be called only after a successful call to + /// `poll_ready_event`. + pub fn start_send_event(&mut self, event: E) { let msg = ToTaskMessage::HandlerEvent(event); - if let AsyncSink::NotReady(msg) = self.start_send_event_msg(msg)? { - if let ToTaskMessage::HandlerEvent(event) = msg { - return Ok(AsyncSink::NotReady(event)) - } else { - unreachable!("we tried to send an handler event, so we get one back if not ready") - } - } - Ok(AsyncSink::Ready) + self.start_send_event_msg(msg); } - /// Finish a send operation started with `start_send_event`. - pub fn complete_send_event(&mut self) -> Poll<(), ()> { - self.complete_send_event_msg() + /// Make sure we are ready to accept an event to be sent with `start_send_event`. + pub fn poll_ready_event(&mut self, cx: &mut Context) -> Poll<()> { + self.poll_ready_event_msg(cx) } /// Returns the user data associated with the task. @@ -409,79 +371,38 @@ impl<'a, E, T> TaskEntry<'a, E, T> { /// As soon as our task (`self`) has some acknowledgment from the remote /// that its connection is alive, it will close the connection with `other`. /// - /// Make sure to complete this operation with `complete_take_over`. - #[must_use] - pub fn start_take_over(&mut self, t: ClosedTask) -> StartTakeOver> { + /// Must be called only after a successful call to `poll_ready_take_over`. + pub fn start_take_over(&mut self, t: ClosedTask) { + self.start_send_event_msg(ToTaskMessage::TakeOver(t.sender)); + } + + /// Make sure we are ready to taking over with `start_take_over`. + pub fn poll_ready_take_over(&mut self, cx: &mut Context) -> Poll<()> { + self.poll_ready_event_msg(cx) + } + + /// Sends a message to the task. Must be called only after a successful call to + /// `poll_ready_event`. + /// + /// The API mimicks the one of [`futures::Sink`]. + fn start_send_event_msg(&mut self, msg: ToTaskMessage) { // It is possible that the sender is closed if the background task has already finished // but the local state hasn't been updated yet because we haven't been polled in the // meanwhile. - let id = t.id(); - match self.start_send_event_msg(ToTaskMessage::TakeOver(t.sender)) { - Ok(AsyncSink::Ready) => StartTakeOver::Ready(t.user_data), - Ok(AsyncSink::NotReady(ToTaskMessage::TakeOver(sender))) => - StartTakeOver::NotReady(ClosedTask::new(id, sender, t.user_data)), - Ok(AsyncSink::NotReady(_)) => - unreachable!("We tried to send a take over message, so we get one back."), - Err(()) => StartTakeOver::Gone + match self.inner.get_mut().sender.start_send(msg) { + Ok(()) => {}, + Err(ref err) if err.is_full() => {}, // TODO: somehow report to user? + Err(_) => {}, } } - /// Finish take over started by `start_take_over`. - pub fn complete_take_over(&mut self) -> Poll<(), ()> { - self.complete_send_event_msg() - } - - /// Begin to send a message to the task. - /// - /// The API mimicks the one of [`futures::Sink`]. If this method returns - /// `Ok(AsyncSink::Ready)` drive the sending to completion with - /// `complete_send_event_msg`. If the receiving end does not longer exist, - /// i.e. the task has ended, we return this information as an error. - fn start_send_event_msg(&mut self, msg: ToTaskMessage) -> StartSend, ()> { - // We first drive any pending send to completion before starting another one. - if self.complete_send_event_msg()?.is_ready() { - self.inner.get_mut().pending = Some(AsyncSink::NotReady(msg)); - Ok(AsyncSink::Ready) - } else { - Ok(AsyncSink::NotReady(msg)) - } - } - - /// Complete event message deliver started by `start_send_event_msg`. 
- fn complete_send_event_msg(&mut self) -> Poll<(), ()> { + /// Wait until we have space to send an event using `start_send_event_msg`. + fn poll_ready_event_msg(&mut self, cx: &mut Context) -> Poll<()> { // It is possible that the sender is closed if the background task has already finished // but the local state hasn't been updated yet because we haven't been polled in the // meanwhile. let task = self.inner.get_mut(); - let state = - if let Some(state) = task.pending.take() { - state - } else { - return Ok(Async::Ready(())) - }; - match state { - AsyncSink::NotReady(msg) => - match task.sender.start_send(msg).map_err(|_| ())? { - AsyncSink::Ready => - if task.sender.poll_complete().map_err(|_| ())?.is_not_ready() { - task.pending = Some(AsyncSink::Ready); - Ok(Async::NotReady) - } else { - Ok(Async::Ready(())) - } - AsyncSink::NotReady(msg) => { - task.pending = Some(AsyncSink::NotReady(msg)); - Ok(Async::NotReady) - } - } - AsyncSink::Ready => - if task.sender.poll_complete().map_err(|_| ())?.is_not_ready() { - task.pending = Some(AsyncSink::Ready); - Ok(Async::NotReady) - } else { - Ok(Async::Ready(())) - } - } + task.sender.poll_ready(cx).map(|_| ()) } } @@ -494,18 +415,6 @@ impl fmt::Debug for TaskEntry<'_, E, T> { } } -/// Result of [`TaskEntry::start_take_over`]. -#[derive(Debug)] -pub enum StartTakeOver { - /// The take over message has been enqueued. - /// Complete the take over with [`TaskEntry::complete_take_over`]. - Ready(A), - /// Not ready to send the take over message to the task. - NotReady(B), - /// The task to send the take over message is no longer there. - Gone -} - /// Task after it has been closed. /// /// The connection to the remote is potentially still going on, but no new diff --git a/core/src/nodes/tasks/mod.rs b/core/src/nodes/tasks/mod.rs index baa1a081..2af4939c 100644 --- a/core/src/nodes/tasks/mod.rs +++ b/core/src/nodes/tasks/mod.rs @@ -37,7 +37,7 @@ mod manager; mod task; pub use error::Error; -pub use manager::{ClosedTask, TaskEntry, Manager, Event, StartTakeOver}; +pub use manager::{ClosedTask, TaskEntry, Manager, Event}; /// Task identifier. #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] diff --git a/core/src/nodes/tasks/task.rs b/core/src/nodes/tasks/task.rs index 05b801e1..992a59bf 100644 --- a/core/src/nodes/tasks/task.rs +++ b/core/src/nodes/tasks/task.rs @@ -25,8 +25,9 @@ use crate::{ node::{Close, Substream} } }; -use futures::{prelude::*, stream, sync::mpsc}; +use futures::{prelude::*, channel::mpsc, stream}; use smallvec::SmallVec; +use std::{pin::Pin, task::Context, task::Poll}; use super::{TaskId, Error}; /// Message to transmit from the public API to a task. @@ -140,13 +141,6 @@ where event: FromTaskMessage::Error, C> }, - /// We started sending an event, now drive the sending to completion. - /// - /// The `bool` parameter determines if we transition to `State::Node` - /// afterwards or to `State::Closing` (assuming we have `Some` node, - /// otherwise the task will end). - PollComplete(Option>, bool), - /// Fully functional node. 
Node(HandledNode), @@ -158,94 +152,103 @@ where Undefined } +impl Unpin for Task +where + M: StreamMuxer, + H: IntoNodeHandler, + H::Handler: NodeHandler> +{ +} + impl Future for Task where M: StreamMuxer, - F: Future, + F: Future> + Unpin, H: IntoNodeHandler, H::Handler: NodeHandler, InEvent = I, OutEvent = O> { - type Item = (); - type Error = (); + type Output = (); // NOTE: It is imperative to always consume all incoming event messages // first in order to not prevent the outside from making progress because // they are blocked on the channel capacity. - fn poll(&mut self) -> Poll<(), ()> { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> { + // We use a `this` because the compiler isn't smart enough to allow mutably borrowing + // multiple different fields from the `Pin` at the same time. + let this = &mut *self; + 'poll: loop { - match std::mem::replace(&mut self.state, State::Undefined) { + match std::mem::replace(&mut this.state, State::Undefined) { State::Future { mut future, handler, mut events_buffer } => { - // If self.receiver is closed, we stop the task. + // If this.receiver is closed, we stop the task. loop { - match self.receiver.poll() { - Ok(Async::NotReady) => break, - Ok(Async::Ready(None)) => return Ok(Async::Ready(())), - Ok(Async::Ready(Some(ToTaskMessage::HandlerEvent(event)))) => + match Stream::poll_next(Pin::new(&mut this.receiver), cx) { + Poll::Pending => break, + Poll::Ready(None) => return Poll::Ready(()), + Poll::Ready(Some(ToTaskMessage::HandlerEvent(event))) => events_buffer.push(event), - Ok(Async::Ready(Some(ToTaskMessage::TakeOver(take_over)))) => - self.taken_over.push(take_over), - Err(()) => unreachable!("An `mpsc::Receiver` does not error.") + Poll::Ready(Some(ToTaskMessage::TakeOver(take_over))) => + this.taken_over.push(take_over), } } // Check if dialing succeeded. - match future.poll() { - Ok(Async::Ready((conn_info, muxer))) => { + match Future::poll(Pin::new(&mut future), cx) { + Poll::Ready(Ok((conn_info, muxer))) => { let mut node = HandledNode::new(muxer, handler.into_handler(&conn_info)); for event in events_buffer { node.inject_event(event) } - self.state = State::SendEvent { + this.state = State::SendEvent { node: Some(node), event: FromTaskMessage::NodeReached(conn_info) } } - Ok(Async::NotReady) => { - self.state = State::Future { future, handler, events_buffer }; - return Ok(Async::NotReady) + Poll::Pending => { + this.state = State::Future { future, handler, events_buffer }; + return Poll::Pending } - Err(e) => { + Poll::Ready(Err(e)) => { let event = FromTaskMessage::TaskClosed(Error::Reach(e), Some(handler)); - self.state = State::SendEvent { node: None, event } + this.state = State::SendEvent { node: None, event } } } } State::Node(mut node) => { // Start by handling commands received from the outside of the task. loop { - match self.receiver.poll() { - Ok(Async::NotReady) => break, - Ok(Async::Ready(Some(ToTaskMessage::HandlerEvent(event)))) => + match Stream::poll_next(Pin::new(&mut this.receiver), cx) { + Poll::Pending => break, + Poll::Ready(Some(ToTaskMessage::HandlerEvent(event))) => node.inject_event(event), - Ok(Async::Ready(Some(ToTaskMessage::TakeOver(take_over)))) => - self.taken_over.push(take_over), - Ok(Async::Ready(None)) => { + Poll::Ready(Some(ToTaskMessage::TakeOver(take_over))) => + this.taken_over.push(take_over), + Poll::Ready(None) => { // Node closed by the external API; start closing. 
- self.state = State::Closing(node.close()); + this.state = State::Closing(node.close()); continue 'poll } - Err(()) => unreachable!("An `mpsc::Receiver` does not error.") } } // Process the node. loop { - if !self.taken_over.is_empty() && node.is_remote_acknowledged() { - self.taken_over.clear() + if !this.taken_over.is_empty() && node.is_remote_acknowledged() { + this.taken_over.clear() } - match node.poll() { - Ok(Async::NotReady) => { - self.state = State::Node(node); - return Ok(Async::NotReady) + match HandledNode::poll(Pin::new(&mut node), cx) { + Poll::Pending => { + this.state = State::Node(node); + return Poll::Pending } - Ok(Async::Ready(event)) => { - self.state = State::SendEvent { + Poll::Ready(Ok(event)) => { + this.state = State::SendEvent { node: Some(node), event: FromTaskMessage::NodeEvent(event) }; continue 'poll } - Err(err) => { + Poll::Ready(Err(err)) => { let event = FromTaskMessage::TaskClosed(Error::Node(err), None); - self.state = State::SendEvent { node: None, event }; + this.state = State::SendEvent { node: None, event }; continue 'poll } } @@ -254,23 +257,22 @@ where // Deliver an event to the outside. State::SendEvent { mut node, event } => { loop { - match self.receiver.poll() { - Ok(Async::NotReady) => break, - Ok(Async::Ready(Some(ToTaskMessage::HandlerEvent(event)))) => + match Stream::poll_next(Pin::new(&mut this.receiver), cx) { + Poll::Pending => break, + Poll::Ready(Some(ToTaskMessage::HandlerEvent(event))) => if let Some(ref mut n) = node { n.inject_event(event) } - Ok(Async::Ready(Some(ToTaskMessage::TakeOver(take_over)))) => - self.taken_over.push(take_over), - Ok(Async::Ready(None)) => + Poll::Ready(Some(ToTaskMessage::TakeOver(take_over))) => + this.taken_over.push(take_over), + Poll::Ready(None) => // Node closed by the external API; start closing. if let Some(n) = node { - self.state = State::Closing(n.close()); + this.state = State::Closing(n.close()); continue 'poll } else { - return Ok(Async::Ready(())) // end task + return Poll::Ready(()) // end task } - Err(()) => unreachable!("An `mpsc::Receiver` does not error.") } } // Check if this task is about to close. We pass the flag to @@ -281,80 +283,46 @@ where } else { false }; - match self.sender.start_send((event, self.id)) { - Ok(AsyncSink::NotReady((event, _))) => { + match this.sender.poll_ready(cx) { + Poll::Pending => { self.state = State::SendEvent { node, event }; - return Ok(Async::NotReady) + return Poll::Pending } - Ok(AsyncSink::Ready) => self.state = State::PollComplete(node, close), - Err(_) => { - if let Some(n) = node { - self.state = State::Closing(n.close()); - continue 'poll - } - // We can not communicate to the outside and there is no - // node to handle, so this is the end of this task. - return Ok(Async::Ready(())) - } - } - } - // We started delivering an event, now try to complete the sending. - State::PollComplete(mut node, close) => { - loop { - match self.receiver.poll() { - Ok(Async::NotReady) => break, - Ok(Async::Ready(Some(ToTaskMessage::HandlerEvent(event)))) => - if let Some(ref mut n) = node { - n.inject_event(event) - } - Ok(Async::Ready(Some(ToTaskMessage::TakeOver(take_over)))) => - self.taken_over.push(take_over), - Ok(Async::Ready(None)) => - // Node closed by the external API; start closing. 
- if let Some(n) = node { - self.state = State::Closing(n.close()); - continue 'poll - } else { - return Ok(Async::Ready(())) // end task - } - Err(()) => unreachable!("An `mpsc::Receiver` does not error.") - } - } - match self.sender.poll_complete() { - Ok(Async::NotReady) => { - self.state = State::PollComplete(node, close); - return Ok(Async::NotReady) - } - Ok(Async::Ready(())) => + Poll::Ready(Ok(())) => { + // We assume that if `poll_ready` has succeeded, then sending the event + // will succeed as well. If it turns out that it didn't, we will detect + // the closing at the next loop iteration. + let _ = this.sender.start_send((event, this.id)); if let Some(n) = node { if close { - self.state = State::Closing(n.close()) + this.state = State::Closing(n.close()) } else { - self.state = State::Node(n) + this.state = State::Node(n) } } else { // Since we have no node we terminate this task. assert!(close); - return Ok(Async::Ready(())) + return Poll::Ready(()) } - Err(_) => { + }, + Poll::Ready(Err(_)) => { if let Some(n) = node { - self.state = State::Closing(n.close()); + this.state = State::Closing(n.close()); continue 'poll } // We can not communicate to the outside and there is no // node to handle, so this is the end of this task. - return Ok(Async::Ready(())) + return Poll::Ready(()) } } } State::Closing(mut closing) => - match closing.poll() { - Ok(Async::Ready(())) | Err(_) => - return Ok(Async::Ready(())), // end task - Ok(Async::NotReady) => { - self.state = State::Closing(closing); - return Ok(Async::NotReady) + match Future::poll(Pin::new(&mut closing), cx) { + Poll::Ready(_) => + return Poll::Ready(()), // end task + Poll::Pending => { + this.state = State::Closing(closing); + return Poll::Pending } } // This happens if a previous poll has resolved the future. diff --git a/core/src/tests/dummy_handler.rs b/core/src/tests/dummy_handler.rs deleted file mode 100644 index 2f4ee3fa..00000000 --- a/core/src/tests/dummy_handler.rs +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright 2018 Parity Technologies (UK) Ltd. -// -// Permission is hereby granted, free of charge, to any person obtaining a -// copy of this software and associated documentation files (the "Software"), -// to deal in the Software without restriction, including without limitation -// the rights to use, copy, modify, merge, publish, distribute, sublicense, -// and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. - -//! 
Concrete `NodeHandler` implementation and assorted testing types - -use std::io::{self, Error as IoError}; - -use super::dummy_muxer::DummyMuxer; -use futures::prelude::*; -use crate::muxing::SubstreamRef; -use crate::nodes::handled_node::{HandledNode, NodeHandler, NodeHandlerEndpoint, NodeHandlerEvent}; -use std::sync::Arc; - -#[derive(Debug, PartialEq, Clone)] -pub(crate) struct Handler { - /// Inspect events passed through the Handler - pub events: Vec, - /// Current state of the Handler - pub state: Option, - /// Next state for outbound streams of the Handler - pub next_outbound_state: Option, - /// Vec of states the Handler will assume - pub next_states: Vec, -} - -impl Default for Handler { - fn default() -> Self { - Handler { - events: Vec::new(), - state: None, - next_states: Vec::new(), - next_outbound_state: None, - } - } -} - -#[derive(Debug, PartialEq, Clone)] -pub(crate) enum HandlerState { - Ready(NodeHandlerEvent), - Err, -} - -#[derive(Debug, PartialEq, Clone)] -pub(crate) enum InEvent { - /// A custom inbound event - Custom(&'static str), - /// A substream request with a dummy payload - Substream(Option), - /// Request the handler to move to the next state - NextState, -} - -#[derive(Debug, PartialEq, Clone)] -pub(crate) enum OutEvent { - /// A message from the Handler upwards in the stack - Custom(&'static str), -} - -// Concrete `HandledNode` parametrised for the test helpers -pub(crate) type TestHandledNode = HandledNode; - -impl NodeHandler for Handler { - type InEvent = InEvent; - type OutEvent = OutEvent; - type Error = IoError; - type OutboundOpenInfo = usize; - type Substream = SubstreamRef>; - fn inject_substream( - &mut self, - _: Self::Substream, - endpoint: NodeHandlerEndpoint, - ) { - let user_data = match endpoint { - NodeHandlerEndpoint::Dialer(user_data) => Some(user_data), - NodeHandlerEndpoint::Listener => None, - }; - self.events.push(InEvent::Substream(user_data)); - } - fn inject_event(&mut self, inevent: Self::InEvent) { - self.events.push(inevent.clone()); - match inevent { - InEvent::Custom(s) => { - self.state = Some(HandlerState::Ready(NodeHandlerEvent::Custom( - OutEvent::Custom(s), - ))) - } - InEvent::Substream(Some(user_data)) => { - self.state = Some(HandlerState::Ready( - NodeHandlerEvent::OutboundSubstreamRequest(user_data), - )) - } - InEvent::NextState => { - let next_state = self.next_states.pop(); - self.state = next_state - } - _ => unreachable!(), - } - } - fn poll(&mut self) -> Poll, IoError> { - match self.state.take() { - Some(ref state) => match state { - HandlerState::Ready(event) => Ok(Async::Ready(event.clone())), - HandlerState::Err => Err(io::Error::new(io::ErrorKind::Other, "oh noes")), - }, - None => Ok(Async::NotReady), - } - } -} diff --git a/core/src/tests/dummy_muxer.rs b/core/src/tests/dummy_muxer.rs deleted file mode 100644 index eb4bbb16..00000000 --- a/core/src/tests/dummy_muxer.rs +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright 2018 Parity Technologies (UK) Ltd. 
-// -// Permission is hereby granted, free of charge, to any person obtaining a -// copy of this software and associated documentation files (the "Software"), -// to deal in the Software without restriction, including without limitation -// the rights to use, copy, modify, merge, publish, distribute, sublicense, -// and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. - -//! `DummyMuxer` is a `StreamMuxer` to be used in tests. It implements a bare-bones -//! version of the trait along with a way to setup the muxer to behave in the -//! desired way when testing other components. - -use futures::prelude::*; -use crate::muxing::StreamMuxer; -use std::io::Error as IoError; - -/// Substream type -#[derive(Debug)] -pub struct DummySubstream {} - -/// OutboundSubstream type -#[derive(Debug)] -pub struct DummyOutboundSubstream {} - -/// Control the muxer state by setting the "connection" state as to set up a mock -/// muxer for higher level components. -#[derive(Debug, PartialEq, Clone)] -pub enum DummyConnectionState { - Pending, // use this to trigger the Async::NotReady code path - Opened, // use this to trigger the Async::Ready(_) code path -} -#[derive(Debug, PartialEq, Clone)] -struct DummyConnection { - state: DummyConnectionState, -} - -/// `DummyMuxer` implements `StreamMuxer` and methods to control its behaviour when used in tests -#[derive(Debug, PartialEq, Clone)] -pub struct DummyMuxer{ - in_connection: DummyConnection, - out_connection: DummyConnection, -} - -impl DummyMuxer { - /// Create a new `DummyMuxer` where the inbound substream is set to `Pending` - /// and the (single) outbound substream to `Pending`. 
- pub fn new() -> Self { - DummyMuxer { - in_connection: DummyConnection { - state: DummyConnectionState::Pending, - }, - out_connection: DummyConnection { - state: DummyConnectionState::Pending, - }, - } - } - /// Set the muxer state inbound "connection" state - pub fn set_inbound_connection_state(&mut self, state: DummyConnectionState) { - self.in_connection.state = state - } - /// Set the muxer state outbound "connection" state - pub fn set_outbound_connection_state(&mut self, state: DummyConnectionState) { - self.out_connection.state = state - } -} - -impl StreamMuxer for DummyMuxer { - type Substream = DummySubstream; - type OutboundSubstream = DummyOutboundSubstream; - type Error = IoError; - fn poll_inbound(&self) -> Poll { - match self.in_connection.state { - DummyConnectionState::Pending => Ok(Async::NotReady), - DummyConnectionState::Opened => Ok(Async::Ready(Self::Substream {})), - } - } - fn open_outbound(&self) -> Self::OutboundSubstream { - Self::OutboundSubstream {} - } - fn poll_outbound( - &self, - _substream: &mut Self::OutboundSubstream, - ) -> Poll { - match self.out_connection.state { - DummyConnectionState::Pending => Ok(Async::NotReady), - DummyConnectionState::Opened => Ok(Async::Ready(Self::Substream {})), - } - } - fn destroy_outbound(&self, _: Self::OutboundSubstream) {} - fn read_substream(&self, _: &mut Self::Substream, _buf: &mut [u8]) -> Poll { - unreachable!() - } - fn write_substream(&self, _: &mut Self::Substream, _buf: &[u8]) -> Poll { - unreachable!() - } - fn flush_substream(&self, _: &mut Self::Substream) -> Poll<(), IoError> { - unreachable!() - } - fn shutdown_substream(&self, _: &mut Self::Substream) -> Poll<(), IoError> { - unreachable!() - } - fn destroy_substream(&self, _: Self::Substream) {} - fn is_remote_acknowledged(&self) -> bool { true } - fn close(&self) -> Poll<(), IoError> { - Ok(Async::Ready(())) - } - fn flush_all(&self) -> Poll<(), IoError> { - Ok(Async::Ready(())) - } -} diff --git a/core/src/tests/dummy_transport.rs b/core/src/tests/dummy_transport.rs deleted file mode 100644 index 0622ec0e..00000000 --- a/core/src/tests/dummy_transport.rs +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright 2018 Parity Technologies (UK) Ltd. -// -// Permission is hereby granted, free of charge, to any person obtaining a -// copy of this software and associated documentation files (the "Software"), -// to deal in the Software without restriction, including without limitation -// the rights to use, copy, modify, merge, publish, distribute, sublicense, -// and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. - -//! `DummyTransport` is a `Transport` used in tests. It implements a bare-bones -//! version of the trait along with a way to setup the transport listeners with -//! an initial state to facilitate testing. 
- -use futures::prelude::*; -use futures::{ - future::{self, FutureResult}, - stream, -}; -use std::io; -use crate::{Multiaddr, PeerId, Transport, transport::{ListenerEvent, TransportError}}; -use crate::tests::dummy_muxer::DummyMuxer; - -#[derive(Debug, PartialEq, Clone)] -pub(crate) enum ListenerState { - Ok(Async>>), - Error, - Events(Vec>) -} - -#[derive(Debug, PartialEq, Clone)] -pub(crate) struct DummyTransport { - /// The current state of Listeners. - listener_state: ListenerState, - /// The next peer returned from dial(). - next_peer_id: Option, - /// When true, all dial attempts return error. - dial_should_fail: bool, -} -impl DummyTransport { - pub(crate) fn new() -> Self { - DummyTransport { - listener_state: ListenerState::Ok(Async::NotReady), - next_peer_id: None, - dial_should_fail: false, - } - } - pub(crate) fn set_initial_listener_state(&mut self, state: ListenerState) { - self.listener_state = state; - } - - pub(crate) fn set_next_peer_id(&mut self, peer_id: &PeerId) { - self.next_peer_id = Some(peer_id.clone()); - } - - pub(crate) fn make_dial_fail(&mut self) { - self.dial_should_fail = true; - } -} -impl Transport for DummyTransport { - type Output = (PeerId, DummyMuxer); - type Error = io::Error; - type Listener = Box, Error=io::Error> + Send>; - type ListenerUpgrade = FutureResult; - type Dial = Box + Send>; - - fn listen_on(self, addr: Multiaddr) -> Result> - where - Self: Sized, - { - match self.listener_state { - ListenerState::Ok(state) => match state { - Async::NotReady => Ok(Box::new(stream::poll_fn(|| Ok(Async::NotReady)))), - Async::Ready(Some(event)) => Ok(Box::new(stream::poll_fn(move || { - Ok(Async::Ready(Some(event.clone().map(future::ok)))) - }))), - Async::Ready(None) => Ok(Box::new(stream::empty())) - }, - ListenerState::Error => Err(TransportError::MultiaddrNotSupported(addr)), - ListenerState::Events(events) => - Ok(Box::new(stream::iter_ok(events.into_iter().map(|e| e.map(future::ok))))) - } - } - - fn dial(self, _addr: Multiaddr) -> Result> - where - Self: Sized, - { - let peer_id = if let Some(peer_id) = self.next_peer_id { - peer_id - } else { - PeerId::random() - }; - - let fut = - if self.dial_should_fail { - let err_string = format!("unreachable host error, peer={:?}", peer_id); - future::err(io::Error::new(io::ErrorKind::Other, err_string)) - } else { - future::ok((peer_id, DummyMuxer::new())) - }; - - Ok(Box::new(fut)) - } -} diff --git a/core/src/tests/mod.rs b/core/src/tests/mod.rs deleted file mode 100644 index 5c86aec1..00000000 --- a/core/src/tests/mod.rs +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2018 Parity Technologies (UK) Ltd. -// -// Permission is hereby granted, free of charge, to any person obtaining a -// copy of this software and associated documentation files (the "Software"), -// to deal in the Software without restriction, including without limitation -// the rights to use, copy, modify, merge, publish, distribute, sublicense, -// and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. - -#[cfg(test)] -pub(crate) mod dummy_muxer; - -#[cfg(test)] -pub(crate) mod dummy_transport; - -#[cfg(test)] -pub(crate) mod dummy_handler; diff --git a/core/src/transport/and_then.rs b/core/src/transport/and_then.rs index d4233c44..a2e7ed61 100644 --- a/core/src/transport/and_then.rs +++ b/core/src/transport/and_then.rs @@ -23,9 +23,9 @@ use crate::{ either::EitherError, transport::{Transport, TransportError, ListenerEvent} }; -use futures::{future::Either, prelude::*, try_ready}; +use futures::{future::Either, prelude::*}; use multiaddr::Multiaddr; -use std::error; +use std::{error, pin::Pin, task::Context, task::Poll}; /// See the `Transport::and_then` method. #[derive(Debug, Clone)] @@ -40,15 +40,18 @@ impl AndThen { impl Transport for AndThen where T: Transport, + T::Dial: Unpin, + T::Listener: Unpin, + T::ListenerUpgrade: Unpin, C: FnOnce(T::Output, ConnectedPoint) -> F + Clone, - F: IntoFuture, + F: TryFuture + Unpin, F::Error: error::Error, { type Output = O; type Error = EitherError; type Listener = AndThenStream; - type ListenerUpgrade = AndThenFuture; - type Dial = AndThenFuture; + type ListenerUpgrade = AndThenFuture; + type Dial = AndThenFuture; fn listen_on(self, addr: Multiaddr) -> Result> { let listener = self.transport.listen_on(addr).map_err(|err| err.map(EitherError::A))?; @@ -63,7 +66,7 @@ where fn dial(self, addr: Multiaddr) -> Result> { let dialed_fut = self.transport.dial(addr.clone()).map_err(|err| err.map(EitherError::A))?; let future = AndThenFuture { - inner: Either::A(dialed_fut), + inner: Either::Left(dialed_fut), args: Some((self.fun, ConnectedPoint::Dialer { address: addr })) }; Ok(future) @@ -79,19 +82,24 @@ pub struct AndThenStream { fun: TMap } +impl Unpin for AndThenStream { +} + impl Stream for AndThenStream where - TListener: Stream, Error = TTransErr>, - TListUpgr: Future, + TListener: TryStream, Error = TTransErr> + Unpin, + TListUpgr: TryFuture, TMap: FnOnce(TTransOut, ConnectedPoint) -> TMapOut + Clone, - TMapOut: IntoFuture + TMapOut: TryFuture { - type Item = ListenerEvent>; - type Error = EitherError; + type Item = Result< + ListenerEvent>, + EitherError + >; - fn poll(&mut self) -> Poll, Self::Error> { - match self.stream.poll().map_err(EitherError::A)? 
{ - Async::Ready(Some(event)) => { + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match TryStream::try_poll_next(Pin::new(&mut self.stream), cx) { + Poll::Ready(Some(Ok(event))) => { let event = match event { ListenerEvent::Upgrade { upgrade, local_addr, remote_addr } => { let point = ConnectedPoint::Listener { @@ -100,7 +108,7 @@ where }; ListenerEvent::Upgrade { upgrade: AndThenFuture { - inner: Either::A(upgrade), + inner: Either::Left(upgrade), args: Some((self.fun.clone(), point)) }, local_addr, @@ -110,10 +118,11 @@ where ListenerEvent::NewAddress(a) => ListenerEvent::NewAddress(a), ListenerEvent::AddressExpired(a) => ListenerEvent::AddressExpired(a) }; - Ok(Async::Ready(Some(event))) + Poll::Ready(Some(Ok(event))) } - Async::Ready(None) => Ok(Async::Ready(None)), - Async::NotReady => Ok(Async::NotReady) + Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(EitherError::A(err)))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending } } } @@ -127,28 +136,39 @@ pub struct AndThenFuture { args: Option<(TMap, ConnectedPoint)> } -impl Future for AndThenFuture -where - TFut: Future, - TMap: FnOnce(TFut::Item, ConnectedPoint) -> TMapOut, - TMapOut: IntoFuture -{ - type Item = ::Item; - type Error = EitherError; +impl Unpin for AndThenFuture { +} - fn poll(&mut self) -> Poll { +impl Future for AndThenFuture +where + TFut: TryFuture + Unpin, + TMap: FnOnce(TFut::Ok, ConnectedPoint) -> TMapOut, + TMapOut: TryFuture + Unpin +{ + type Output = Result>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { loop { - let future = match self.inner { - Either::A(ref mut future) => { - let item = try_ready!(future.poll().map_err(EitherError::A)); + let future = match (*self).inner { + Either::Left(ref mut future) => { + let item = match TryFuture::try_poll(Pin::new(future), cx) { + Poll::Ready(Ok(v)) => v, + Poll::Ready(Err(err)) => return Poll::Ready(Err(EitherError::A(err))), + Poll::Pending => return Poll::Pending, + }; let (f, a) = self.args.take().expect("AndThenFuture has already finished."); - f(item, a).into_future() + f(item, a) + } + Either::Right(ref mut future) => { + return match TryFuture::try_poll(Pin::new(future), cx) { + Poll::Ready(Ok(v)) => Poll::Ready(Ok(v)), + Poll::Ready(Err(err)) => return Poll::Ready(Err(EitherError::B(err))), + Poll::Pending => Poll::Pending, + } } - Either::B(ref mut future) => return future.poll().map_err(EitherError::B) }; - self.inner = Either::B(future); + (*self).inner = Either::Right(future); } } } - diff --git a/core/src/transport/boxed.rs b/core/src/transport/boxed.rs index 73589423..3d7b95b3 100644 --- a/core/src/transport/boxed.rs +++ b/core/src/transport/boxed.rs @@ -21,7 +21,7 @@ use crate::transport::{ListenerEvent, Transport, TransportError}; use futures::prelude::*; use multiaddr::Multiaddr; -use std::{error, fmt, sync::Arc}; +use std::{error, fmt, pin::Pin, sync::Arc}; /// See the `Transport::boxed` method. 
#[inline] @@ -37,9 +37,9 @@ where } } -pub type Dial = Box + Send>; -pub type Listener = Box>, Error = E> + Send>; -pub type ListenerUpgrade = Box + Send>; +pub type Dial = Pin> + Send>>; +pub type Listener = Pin>, E>> + Send>>; +pub type ListenerUpgrade = Pin> + Send>>; trait Abstract { fn listen_on(&self, addr: Multiaddr) -> Result, TransportError>; @@ -56,15 +56,15 @@ where { fn listen_on(&self, addr: Multiaddr) -> Result, TransportError> { let listener = Transport::listen_on(self.clone(), addr)?; - let fut = listener.map(|event| event.map(|upgrade| { - Box::new(upgrade) as ListenerUpgrade + let fut = listener.map_ok(|event| event.map(|upgrade| { + Box::pin(upgrade) as ListenerUpgrade })); - Ok(Box::new(fut) as Box<_>) + Ok(Box::pin(fut)) } fn dial(&self, addr: Multiaddr) -> Result, TransportError> { let fut = Transport::dial(self.clone(), addr)?; - Ok(Box::new(fut) as Box<_>) + Ok(Box::pin(fut) as Dial<_, _>) } } diff --git a/core/src/transport/choice.rs b/core/src/transport/choice.rs index c6593912..c3bfc15d 100644 --- a/core/src/transport/choice.rs +++ b/core/src/transport/choice.rs @@ -35,7 +35,13 @@ impl OrTransport { impl Transport for OrTransport where B: Transport, + B::Dial: Unpin, + B::Listener: Unpin, + B::ListenerUpgrade: Unpin, A: Transport, + A::Dial: Unpin, + A::Listener: Unpin, + A::ListenerUpgrade: Unpin, { type Output = EitherOutput; type Error = EitherError; diff --git a/core/src/transport/dummy.rs b/core/src/transport/dummy.rs index 4d478016..f3256b27 100644 --- a/core/src/transport/dummy.rs +++ b/core/src/transport/dummy.rs @@ -20,7 +20,8 @@ use crate::transport::{Transport, TransportError, ListenerEvent}; use crate::Multiaddr; -use std::{fmt, io, marker::PhantomData}; +use futures::{prelude::*, task::Context, task::Poll}; +use std::{fmt, io, marker::PhantomData, pin::Pin}; /// Implementation of `Transport` that doesn't support any multiaddr. /// @@ -55,9 +56,9 @@ impl Clone for DummyTransport { impl Transport for DummyTransport { type Output = TOut; type Error = io::Error; - type Listener = futures::stream::Empty, io::Error>; - type ListenerUpgrade = futures::future::Empty; - type Dial = futures::future::Empty; + type Listener = futures::stream::Pending, io::Error>>; + type ListenerUpgrade = futures::future::Pending>; + type Dial = futures::future::Pending>; fn listen_on(self, addr: Multiaddr) -> Result> { Err(TransportError::MultiaddrNotSupported(addr)) @@ -68,7 +69,7 @@ impl Transport for DummyTransport { } } -/// Implementation of `Read` and `Write`. Not meant to be instanciated. +/// Implementation of `AsyncRead` and `AsyncWrite`. Not meant to be instanciated. 
pub struct DummyStream(()); impl fmt::Debug for DummyStream { @@ -77,30 +78,30 @@ impl fmt::Debug for DummyStream { } } -impl io::Read for DummyStream { - fn read(&mut self, _: &mut [u8]) -> io::Result { - Err(io::ErrorKind::Other.into()) +impl AsyncRead for DummyStream { + fn poll_read(self: Pin<&mut Self>, _: &mut Context, _: &mut [u8]) + -> Poll> + { + Poll::Ready(Err(io::ErrorKind::Other.into())) } } -impl io::Write for DummyStream { - fn write(&mut self, _: &[u8]) -> io::Result { - Err(io::ErrorKind::Other.into()) +impl AsyncWrite for DummyStream { + fn poll_write(self: Pin<&mut Self>, _: &mut Context, _: &[u8]) + -> Poll> + { + Poll::Ready(Err(io::ErrorKind::Other.into())) } - fn flush(&mut self) -> io::Result<()> { - Err(io::ErrorKind::Other.into()) + fn poll_flush(self: Pin<&mut Self>, _: &mut Context) + -> Poll> + { + Poll::Ready(Err(io::ErrorKind::Other.into())) } -} -impl tokio_io::AsyncRead for DummyStream { - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { - false - } -} - -impl tokio_io::AsyncWrite for DummyStream { - fn shutdown(&mut self) -> futures::Poll<(), io::Error> { - Err(io::ErrorKind::Other.into()) + fn poll_close(self: Pin<&mut Self>, _: &mut Context) + -> Poll> + { + Poll::Ready(Err(io::ErrorKind::Other.into())) } } diff --git a/core/src/transport/map.rs b/core/src/transport/map.rs index 53f49b75..7652e892 100644 --- a/core/src/transport/map.rs +++ b/core/src/transport/map.rs @@ -22,8 +22,9 @@ use crate::{ ConnectedPoint, transport::{Transport, TransportError, ListenerEvent} }; -use futures::{prelude::*, try_ready}; +use futures::prelude::*; use multiaddr::Multiaddr; +use std::{pin::Pin, task::Context, task::Poll}; /// See `Transport::map`. #[derive(Debug, Copy, Clone)] @@ -38,6 +39,9 @@ impl Map { impl Transport for Map where T: Transport, + T::Dial: Unpin, + T::Listener: Unpin, + T::ListenerUpgrade: Unpin, F: FnOnce(T::Output, ConnectedPoint) -> D + Clone { type Output = D; @@ -64,18 +68,20 @@ where #[derive(Clone, Debug)] pub struct MapStream { stream: T, fun: F } +impl Unpin for MapStream { +} + impl Stream for MapStream where - T: Stream>, - X: Future, + T: TryStream> + Unpin, + X: TryFuture, F: FnOnce(A, ConnectedPoint) -> B + Clone { - type Item = ListenerEvent>; - type Error = T::Error; + type Item = Result>, T::Error>; - fn poll(&mut self) -> Poll, Self::Error> { - match self.stream.poll()? 
{ - Async::Ready(Some(event)) => { + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match TryStream::try_poll_next(Pin::new(&mut self.stream), cx) { + Poll::Ready(Some(Ok(event))) => { let event = match event { ListenerEvent::Upgrade { upgrade, local_addr, remote_addr } => { let point = ConnectedPoint::Listener { @@ -94,10 +100,11 @@ where ListenerEvent::NewAddress(a) => ListenerEvent::NewAddress(a), ListenerEvent::AddressExpired(a) => ListenerEvent::AddressExpired(a) }; - Ok(Async::Ready(Some(event))) + Poll::Ready(Some(Ok(event))) } - Async::Ready(None) => Ok(Async::Ready(None)), - Async::NotReady => Ok(Async::NotReady) + Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending } } } @@ -111,18 +118,24 @@ pub struct MapFuture { args: Option<(F, ConnectedPoint)> } +impl Unpin for MapFuture { +} + impl Future for MapFuture where - T: Future, + T: TryFuture + Unpin, F: FnOnce(A, ConnectedPoint) -> B { - type Item = B; - type Error = T::Error; + type Output = Result; - fn poll(&mut self) -> Poll { - let item = try_ready!(self.inner.poll()); + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let item = match TryFuture::try_poll(Pin::new(&mut self.inner), cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(v)) => v, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + }; let (f, a) = self.args.take().expect("MapFuture has already finished."); - Ok(Async::Ready(f(item, a))) + Poll::Ready(Ok(f(item, a))) } } diff --git a/core/src/transport/map_err.rs b/core/src/transport/map_err.rs index 0642c681..36f48209 100644 --- a/core/src/transport/map_err.rs +++ b/core/src/transport/map_err.rs @@ -21,7 +21,7 @@ use crate::transport::{Transport, TransportError, ListenerEvent}; use futures::prelude::*; use multiaddr::Multiaddr; -use std::error; +use std::{error, pin::Pin, task::Context, task::Poll}; /// See `Transport::map_err`. 
#[derive(Debug, Copy, Clone)] @@ -40,6 +40,9 @@ impl MapErr { impl Transport for MapErr where T: Transport, + T::Dial: Unpin, + T::Listener: Unpin, + T::ListenerUpgrade: Unpin, F: FnOnce(T::Error) -> TErr + Clone, TErr: error::Error, { @@ -72,29 +75,34 @@ pub struct MapErrListener { map: F, } +impl Unpin for MapErrListener + where T: Transport +{ +} + impl Stream for MapErrListener where T: Transport, + T::Listener: Unpin, F: FnOnce(T::Error) -> TErr + Clone, TErr: error::Error, { - type Item = ListenerEvent>; - type Error = TErr; + type Item = Result>, TErr>; - fn poll(&mut self) -> Poll, Self::Error> { - match self.inner.poll() { - Ok(Async::Ready(Some(event))) => { + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match TryStream::try_poll_next(Pin::new(&mut self.inner), cx) { + Poll::Ready(Some(Ok(event))) => { let event = event.map(move |value| { MapErrListenerUpgrade { inner: value, map: Some(self.map.clone()) } }); - Ok(Async::Ready(Some(event))) + Poll::Ready(Some(Ok(event))) } - Ok(Async::Ready(None)) => Ok(Async::Ready(None)), - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(err) => Err((self.map.clone())(err)), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err((self.map.clone())(err)))), } } } @@ -105,20 +113,25 @@ pub struct MapErrListenerUpgrade { map: Option, } +impl Unpin for MapErrListenerUpgrade + where T: Transport +{ +} + impl Future for MapErrListenerUpgrade where T: Transport, + T::ListenerUpgrade: Unpin, F: FnOnce(T::Error) -> TErr, { - type Item = T::Output; - type Error = TErr; + type Output = Result; - fn poll(&mut self) -> Poll { - match self.inner.poll() { - Ok(Async::Ready(value)) => Ok(Async::Ready(value)), - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(err) => { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + match Future::poll(Pin::new(&mut self.inner), cx) { + Poll::Ready(Ok(value)) => Poll::Ready(Ok(value)), + Poll::Pending => Poll::Pending, + Poll::Ready(Err(err)) => { let map = self.map.take().expect("poll() called again after error"); - Err(map(err)) + Poll::Ready(Err(map(err))) } } } @@ -130,23 +143,26 @@ pub struct MapErrDial { map: Option, } +impl Unpin for MapErrDial + where T: Transport +{ +} + impl Future for MapErrDial where T: Transport, + T::Dial: Unpin, F: FnOnce(T::Error) -> TErr, { - type Item = T::Output; - type Error = TErr; + type Output = Result; - fn poll(&mut self) -> Poll { - match self.inner.poll() { - Ok(Async::Ready(value)) => { - Ok(Async::Ready(value)) - }, - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(err) => { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + match Future::poll(Pin::new(&mut self.inner), cx) { + Poll::Ready(Ok(value)) => Poll::Ready(Ok(value)), + Poll::Pending => Poll::Pending, + Poll::Ready(Err(err)) => { let map = self.map.take().expect("poll() called again after error"); - Err(map(err)) + Poll::Ready(Err(map(err))) } } } diff --git a/core/src/transport/memory.rs b/core/src/transport/memory.rs index 1b399509..e53a1f2b 100644 --- a/core/src/transport/memory.rs +++ b/core/src/transport/memory.rs @@ -21,12 +21,12 @@ use crate::{Transport, transport::{TransportError, ListenerEvent}}; use bytes::{Bytes, IntoBuf}; use fnv::FnvHashMap; -use futures::{future::{self, FutureResult}, prelude::*, sync::mpsc, try_ready}; +use futures::{future::{self, Ready}, prelude::*, channel::mpsc, task::Context, task::Poll}; use lazy_static::lazy_static; use multiaddr::{Protocol, 
Multiaddr}; use parking_lot::Mutex; use rw_stream_sink::RwStreamSink; -use std::{collections::hash_map::Entry, error, fmt, io, num::NonZeroU64}; +use std::{collections::hash_map::Entry, error, fmt, io, num::NonZeroU64, pin::Pin}; lazy_static! { static ref HUB: Mutex>>> = @@ -45,26 +45,24 @@ pub struct DialFuture { } impl Future for DialFuture { - type Item = Channel; - type Error = MemoryTransportError; + type Output = Result, MemoryTransportError>; - fn poll(&mut self) -> Poll { - if let Some(c) = self.channel_to_send.take() { - match self.sender.start_send(c) { - Err(_) => return Err(MemoryTransportError::Unreachable), - Ok(AsyncSink::NotReady(t)) => { - self.channel_to_send = Some(t); - return Ok(Async::NotReady) - }, - _ => (), - } + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + match self.sender.poll_ready(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(())) => {}, + Poll::Ready(Err(_)) => return Poll::Ready(Err(MemoryTransportError::Unreachable)), } - match self.sender.close() { - Err(_) => Err(MemoryTransportError::Unreachable), - Ok(Async::NotReady) => Ok(Async::NotReady), - Ok(Async::Ready(_)) => Ok(Async::Ready(self.channel_to_return.take() - .expect("Future should not be polled again once complete"))), + + let channel_to_send = self.channel_to_send.take() + .expect("Future should not be polled again once complete"); + match self.sender.start_send(channel_to_send) { + Err(_) => return Poll::Ready(Err(MemoryTransportError::Unreachable)), + Ok(()) => {} } + + Poll::Ready(Ok(self.channel_to_return.take() + .expect("Future should not be polled again once complete"))) } } @@ -72,7 +70,7 @@ impl Transport for MemoryTransport { type Output = Channel; type Error = MemoryTransportError; type Listener = Listener; - type ListenerUpgrade = FutureResult; + type ListenerUpgrade = Ready>; type Dial = DialFuture; fn listen_on(self, addr: Multiaddr) -> Result> { @@ -176,26 +174,27 @@ pub struct Listener { } impl Stream for Listener { - type Item = ListenerEvent, MemoryTransportError>>; - type Error = MemoryTransportError; + type Item = Result, MemoryTransportError>>>, MemoryTransportError>; - fn poll(&mut self) -> Poll, Self::Error> { + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { if self.tell_listen_addr { self.tell_listen_addr = false; - return Ok(Async::Ready(Some(ListenerEvent::NewAddress(self.addr.clone())))) + return Poll::Ready(Some(Ok(ListenerEvent::NewAddress(self.addr.clone())))) } - let channel = try_ready!(Ok(self.receiver.poll() - .expect("Life listeners always have a sender."))); - let channel = match channel { - Some(c) => c, - None => return Ok(Async::Ready(None)) + + let channel = match Stream::poll_next(Pin::new(&mut self.receiver), cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(None) => panic!("Alive listeners always have a sender."), + Poll::Ready(Some(v)) => v, }; + let event = ListenerEvent::Upgrade { - upgrade: future::ok(channel), + upgrade: future::ready(Ok(channel)), local_addr: self.addr.clone(), remote_addr: Protocol::Memory(self.port.get()).into() }; - Ok(Async::Ready(Some(event))) + + Poll::Ready(Some(Ok(event))) } } @@ -236,33 +235,39 @@ pub struct Chan { outgoing: mpsc::Sender, } -impl Stream for Chan { - type Item = T; - type Error = io::Error; +impl Unpin for Chan { +} - #[inline] - fn poll(&mut self) -> Poll, Self::Error> { - self.incoming.poll().map_err(|()| io::ErrorKind::BrokenPipe.into()) +impl Stream for Chan { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut 
Context) -> Poll> { + match Stream::poll_next(Pin::new(&mut self.incoming), cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(None) => Poll::Ready(Some(Err(io::ErrorKind::BrokenPipe.into()))), + Poll::Ready(Some(v)) => Poll::Ready(Some(Ok(v))), + } } } -impl Sink for Chan { - type SinkItem = T; - type SinkError = io::Error; +impl Sink for Chan { + type Error = io::Error; - #[inline] - fn start_send(&mut self, item: Self::SinkItem) -> StartSend { + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.outgoing.poll_ready(cx) + .map(|v| v.map_err(|_| io::ErrorKind::BrokenPipe.into())) + } + + fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { self.outgoing.start_send(item).map_err(|_| io::ErrorKind::BrokenPipe.into()) } - #[inline] - fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { - self.outgoing.poll_complete().map_err(|_| io::ErrorKind::BrokenPipe.into()) + fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) } - #[inline] - fn close(&mut self) -> Poll<(), Self::SinkError> { - self.outgoing.close().map_err(|_| io::ErrorKind::BrokenPipe.into()) + fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) } } diff --git a/core/src/transport/mod.rs b/core/src/transport/mod.rs index 6cc9dc7b..4a0a384c 100644 --- a/core/src/transport/mod.rs +++ b/core/src/transport/mod.rs @@ -91,7 +91,7 @@ pub trait Transport { /// transport stack. The item must be a [`ListenerUpgrade`](Transport::ListenerUpgrade) future /// that resolves to an [`Output`](Transport::Output) value once all protocol upgrades /// have been applied. - type Listener: Stream, Error = Self::Error>; + type Listener: TryStream, Error = Self::Error>; /// A pending [`Output`](Transport::Output) for an inbound connection, /// obtained from the [`Listener`](Transport::Listener) stream. @@ -102,11 +102,11 @@ pub trait Transport { /// connection, hence further connection setup proceeds asynchronously. /// Once a `ListenerUpgrade` future resolves it yields the [`Output`](Transport::Output) /// of the connection setup process. - type ListenerUpgrade: Future; + type ListenerUpgrade: Future>; /// A pending [`Output`](Transport::Output) for an outbound connection, /// obtained from [dialing](Transport::dial). - type Dial: Future; + type Dial: Future>; /// Listens on the given [`Multiaddr`], producing a stream of pending, inbound connections /// and addresses this transport is listening on (cf. [`ListenerEvent`]). @@ -175,8 +175,8 @@ pub trait Transport { where Self: Sized, C: FnOnce(Self::Output, ConnectedPoint) -> F + Clone, - F: IntoFuture, - ::Error: Error + 'static + F: TryFuture, + ::Error: Error + 'static { and_then::AndThen::new(self, f) } diff --git a/core/src/transport/timeout.rs b/core/src/transport/timeout.rs index 8a2bde99..c254d241 100644 --- a/core/src/transport/timeout.rs +++ b/core/src/transport/timeout.rs @@ -25,11 +25,9 @@ // TODO: add example use crate::{Multiaddr, Transport, transport::{TransportError, ListenerEvent}}; -use futures::{try_ready, Async, Future, Poll, Stream}; -use log::debug; -use std::{error, fmt, time::Duration}; -use wasm_timer::Timeout; -use wasm_timer::timeout::Error as TimeoutError; +use futures::prelude::*; +use futures_timer::Delay; +use std::{error, fmt, io, pin::Pin, task::Context, task::Poll, time::Duration}; /// A `TransportTimeout` is a `Transport` that wraps another `Transport` and adds /// timeouts to all inbound and outbound connection attempts. 
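The `Chan` sink and `DialFuture` above, like the task manager's `poll_ready_broadcast`/`start_broadcast` pair earlier in this patch, all adopt the split readiness/send protocol of the futures-0.3 `Sink` API: callers first poll for capacity, then hand over the item, instead of `start_send` returning `AsyncSink::NotReady(item)` as in futures 0.1. A minimal sketch of that calling convention, assuming the `futures-preview` 0.3 channel API; `poll_send` is a hypothetical helper and not part of this patch:

use futures::channel::mpsc;
use std::task::{Context, Poll};

// Hypothetical helper mirroring how `poll_ready_event` and `start_send_event_msg`
// are meant to be paired by callers of the task manager.
fn poll_send<T>(
    sender: &mut mpsc::Sender<T>,
    cx: &mut Context<'_>,
    make_msg: impl FnOnce() -> T,
) -> Poll<Result<(), mpsc::SendError>> {
    // Ask the bounded channel for capacity first; a full channel shows up as
    // `Poll::Pending` here rather than handing the item back to the caller.
    match sender.poll_ready(cx) {
        Poll::Pending => return Poll::Pending,
        Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
        Poll::Ready(Ok(())) => {}
    }
    // After a successful `poll_ready`, `start_send` is expected to succeed;
    // the same assumption the task code above spells out in its comments.
    Poll::Ready(sender.start_send(make_msg()))
}

Because the item is only constructed once capacity is known, nothing has to be buffered locally while waiting, which is what allows the `pending` field of `TaskInfo` to be removed in this patch.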
@@ -76,12 +74,15 @@ impl Transport for TransportTimeout where InnerTrans: Transport, InnerTrans::Error: 'static, + InnerTrans::Dial: Unpin, + InnerTrans::Listener: Unpin, + InnerTrans::ListenerUpgrade: Unpin, { type Output = InnerTrans::Output; type Error = TransportTimeoutError; type Listener = TimeoutListener; - type ListenerUpgrade = TokioTimerMapErr>; - type Dial = TokioTimerMapErr>; + type ListenerUpgrade = Timeout; + type Dial = Timeout; fn listen_on(self, addr: Multiaddr) -> Result> { let listener = self.inner.listen_on(addr) @@ -98,8 +99,9 @@ where fn dial(self, addr: Multiaddr) -> Result> { let dial = self.inner.dial(addr) .map_err(|err| err.map(TransportTimeoutError::Other))?; - Ok(TokioTimerMapErr { - inner: Timeout::new(dial, self.outgoing_timeout), + Ok(Timeout { + inner: dial, + timer: Delay::new(self.outgoing_timeout), }) } } @@ -113,21 +115,26 @@ pub struct TimeoutListener { impl Stream for TimeoutListener where - InnerStream: Stream> + InnerStream: TryStream> + Unpin { - type Item = ListenerEvent>>; - type Error = TransportTimeoutError; + type Item = Result>, TransportTimeoutError>; - fn poll(&mut self) -> Poll, Self::Error> { - let poll_out = try_ready!(self.inner.poll().map_err(TransportTimeoutError::Other)); - if let Some(event) = poll_out { - let event = event.map(move |inner_fut| { - TokioTimerMapErr { inner: Timeout::new(inner_fut, self.timeout) } - }); - Ok(Async::Ready(Some(event))) - } else { - Ok(Async::Ready(None)) - } + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let poll_out = match TryStream::try_poll_next(Pin::new(&mut self.inner), cx) { + Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(TransportTimeoutError::Other(err)))), + Poll::Ready(Some(Ok(v))) => v, + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => return Poll::Pending, + }; + + let event = poll_out.map(move |inner_fut| { + Timeout { + inner: inner_fut, + timer: Delay::new(self.timeout), + } + }); + + Poll::Ready(Some(Ok(event))) } } @@ -136,40 +143,44 @@ where // TODO: can be replaced with `impl Future` once `impl Trait` are fully stable in Rust // (https://github.com/rust-lang/rust/issues/34511) #[must_use = "futures do nothing unless polled"] -pub struct TokioTimerMapErr { +pub struct Timeout { inner: InnerFut, + timer: Delay, } -impl Future for TokioTimerMapErr +impl Future for Timeout where - InnerFut: Future>, + InnerFut: TryFuture + Unpin, { - type Item = InnerFut::Item; - type Error = TransportTimeoutError; + type Output = Result>; - fn poll(&mut self) -> Poll { - self.inner.poll().map_err(|err: TimeoutError| { - if err.is_inner() { - TransportTimeoutError::Other(err.into_inner().expect("ensured by is_inner()")) - } else if err.is_elapsed() { - debug!("timeout elapsed for connection"); - TransportTimeoutError::Timeout - } else { - assert!(err.is_timer()); - debug!("tokio timer error in timeout wrapper"); - TransportTimeoutError::TimerError - } - }) + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + // It is debatable whether we should poll the inner future first or the timer first. + // For example, if you start dialing with a timeout of 10 seconds, then after 15 seconds + // the dialing succeeds on the wire, then after 20 seconds you poll, then depending on + // which gets polled first, the outcome will be success or failure. 
+ + match TryFuture::try_poll(Pin::new(&mut self.inner), cx) { + Poll::Pending => {}, + Poll::Ready(Ok(v)) => return Poll::Ready(Ok(v)), + Poll::Ready(Err(err)) => return Poll::Ready(Err(TransportTimeoutError::Other(err))), + } + + match TryFuture::try_poll(Pin::new(&mut self.timer), cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(())) => Poll::Ready(Err(TransportTimeoutError::Timeout)), + Poll::Ready(Err(err)) => Poll::Ready(Err(TransportTimeoutError::TimerError(err))), + } } } /// Error that can be produced by the `TransportTimeout` layer. -#[derive(Debug, Copy, Clone)] +#[derive(Debug)] pub enum TransportTimeoutError { /// The transport timed out. Timeout, /// An error happened in the timer. - TimerError, + TimerError(io::Error), /// Other kind of error. Other(TErr), } @@ -180,7 +191,7 @@ where TErr: fmt::Display, fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { TransportTimeoutError::Timeout => write!(f, "Timeout has been reached"), - TransportTimeoutError::TimerError => write!(f, "Error in the timer"), + TransportTimeoutError::TimerError(err) => write!(f, "Error in the timer: {}", err), TransportTimeoutError::Other(err) => write!(f, "{}", err), } } @@ -192,7 +203,7 @@ where TErr: error::Error + 'static, fn source(&self) -> Option<&(dyn error::Error + 'static)> { match self { TransportTimeoutError::Timeout => None, - TransportTimeoutError::TimerError => None, + TransportTimeoutError::TimerError(err) => Some(err), TransportTimeoutError::Other(err) => Some(err), } } diff --git a/core/src/transport/upgrade.rs b/core/src/transport/upgrade.rs index 4a4535ff..289bbdbc 100644 --- a/core/src/transport/upgrade.rs +++ b/core/src/transport/upgrade.rs @@ -41,10 +41,9 @@ use crate::{ InboundUpgradeApply } }; -use futures::{future, prelude::*, try_ready}; +use futures::{prelude::*, ready}; use multiaddr::Multiaddr; -use std::{error::Error, fmt}; -use tokio_io::{AsyncRead, AsyncWrite}; +use std::{error::Error, fmt, pin::Pin, task::Context, task::Poll}; /// A `Builder` facilitates upgrading of a [`Transport`] for use with /// a [`Network`]. @@ -98,9 +97,12 @@ where AndThen Authenticate + Clone> > where T: Transport, + T::Dial: Unpin, + T::Listener: Unpin, + T::ListenerUpgrade: Unpin, I: ConnectionInfo, - C: AsyncRead + AsyncWrite, - D: AsyncRead + AsyncWrite, + C: AsyncRead + AsyncWrite + Unpin, + D: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade, U: OutboundUpgrade + Clone, E: Error + 'static, @@ -126,8 +128,11 @@ where pub fn apply(self, upgrade: U) -> Builder> where T: Transport, - C: AsyncRead + AsyncWrite, - D: AsyncRead + AsyncWrite, + T::Dial: Unpin, + T::Listener: Unpin, + T::ListenerUpgrade: Unpin, + C: AsyncRead + AsyncWrite + Unpin, + D: AsyncRead + AsyncWrite + Unpin, I: ConnectionInfo, U: InboundUpgrade, U: OutboundUpgrade + Clone, @@ -151,7 +156,10 @@ where -> AndThen Multiplex + Clone> where T: Transport, - C: AsyncRead + AsyncWrite, + T::Dial: Unpin, + T::Listener: Unpin, + T::ListenerUpgrade: Unpin, + C: AsyncRead + AsyncWrite + Unpin, M: StreamMuxer, I: ConnectionInfo, U: InboundUpgrade, @@ -171,7 +179,7 @@ where /// Configured through [`Builder::authenticate`]. 
pub struct Authenticate where - C: AsyncRead + AsyncWrite, + C: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade + OutboundUpgrade { inner: EitherUpgrade @@ -179,17 +187,16 @@ where impl Future for Authenticate where - C: AsyncRead + AsyncWrite, + C: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade + OutboundUpgrade>::Output, Error = >::Error > { - type Item = as Future>::Item; - type Error = as Future>::Error; + type Output = as Future>::Output; - fn poll(&mut self) -> Poll { - self.inner.poll() + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + Future::poll(Pin::new(&mut self.inner), cx) } } @@ -199,7 +206,7 @@ where /// Configured through [`Builder::multiplex`]. pub struct Multiplex where - C: AsyncRead + AsyncWrite, + C: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade + OutboundUpgrade, { info: Option, @@ -208,20 +215,29 @@ where impl Future for Multiplex where - C: AsyncRead + AsyncWrite, + C: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade, U: OutboundUpgrade { - type Item = (I, M); - type Error = UpgradeError; + type Output = Result<(I, M), UpgradeError>; - fn poll(&mut self) -> Poll { - let m = try_ready!(self.upgrade.poll()); + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let m = match ready!(Future::poll(Pin::new(&mut self.upgrade), cx)) { + Ok(m) => m, + Err(err) => return Poll::Ready(Err(err)), + }; let i = self.info.take().expect("Multiplex future polled after completion."); - Ok(Async::Ready((i, m))) + Poll::Ready(Ok((i, m))) } } +impl Unpin for Multiplex +where + C: AsyncRead + AsyncWrite + Unpin, + U: InboundUpgrade + OutboundUpgrade, +{ +} + /// An inbound or outbound upgrade. type EitherUpgrade = future::Either, OutboundUpgradeApply>; @@ -240,8 +256,11 @@ impl Upgrade { impl Transport for Upgrade where T: Transport, + T::Dial: Unpin, + T::Listener: Unpin, + T::ListenerUpgrade: Unpin, T::Error: 'static, - C: AsyncRead + AsyncWrite, + C: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade, U: OutboundUpgrade + Clone, E: Error + 'static @@ -257,7 +276,7 @@ where .map_err(|err| err.map(TransportUpgradeError::Transport))?; Ok(DialUpgradeFuture { future, - upgrade: future::Either::A(Some(self.upgrade)) + upgrade: future::Either::Left(Some(self.upgrade)) }) } @@ -310,7 +329,7 @@ where pub struct DialUpgradeFuture where U: OutboundUpgrade, - C: AsyncRead + AsyncWrite, + C: AsyncRead + AsyncWrite + Unpin, { future: F, upgrade: future::Either, (Option, OutboundUpgradeApply)> @@ -318,32 +337,48 @@ where impl Future for DialUpgradeFuture where - F: Future, - C: AsyncRead + AsyncWrite, + F: TryFuture + Unpin, + C: AsyncRead + AsyncWrite + Unpin, U: OutboundUpgrade, U::Error: Error { - type Item = (I, D); - type Error = TransportUpgradeError; + type Output = Result<(I, D), TransportUpgradeError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + // We use a `this` variable because the compiler can't mutably borrow multiple times + // across a `Deref`.
+ let this = &mut *self; - fn poll(&mut self) -> Poll { loop { - self.upgrade = match self.upgrade { - future::Either::A(ref mut up) => { - let (i, c) = try_ready!(self.future.poll().map_err(TransportUpgradeError::Transport)); - let u = up.take().expect("DialUpgradeFuture is constructed with Either::A(Some)."); - future::Either::B((Some(i), apply_outbound(c, u))) + this.upgrade = match this.upgrade { + future::Either::Left(ref mut up) => { + let (i, c) = match ready!(TryFuture::try_poll(Pin::new(&mut this.future), cx).map_err(TransportUpgradeError::Transport)) { + Ok(v) => v, + Err(err) => return Poll::Ready(Err(err)), + }; + let u = up.take().expect("DialUpgradeFuture is constructed with Either::Left(Some)."); + future::Either::Right((Some(i), apply_outbound(c, u))) } - future::Either::B((ref mut i, ref mut up)) => { - let d = try_ready!(up.poll().map_err(TransportUpgradeError::Upgrade)); + future::Either::Right((ref mut i, ref mut up)) => { + let d = match ready!(Future::poll(Pin::new(up), cx).map_err(TransportUpgradeError::Upgrade)) { + Ok(d) => d, + Err(err) => return Poll::Ready(Err(err)), + }; let i = i.take().expect("DialUpgradeFuture polled after completion."); - return Ok(Async::Ready((i, d))) + return Poll::Ready(Ok((i, d))) } } } } } +impl Unpin for DialUpgradeFuture +where + U: OutboundUpgrade, + C: AsyncRead + AsyncWrite + Unpin, +{ +} + /// The [`Transport::Listener`] stream of an [`Upgrade`]d transport. pub struct ListenerStream { stream: S, @@ -352,34 +387,39 @@ pub struct ListenerStream { impl Stream for ListenerStream where - S: Stream>, - F: Future, - C: AsyncRead + AsyncWrite, + S: TryStream> + Unpin, + F: TryFuture, + C: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade + Clone { - type Item = ListenerEvent>; - type Error = TransportUpgradeError; + type Item = Result>, TransportUpgradeError>; - fn poll(&mut self) -> Poll, Self::Error> { - match try_ready!(self.stream.poll().map_err(TransportUpgradeError::Transport)) { - Some(event) => { + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match ready!(TryStream::try_poll_next(Pin::new(&mut self.stream), cx)) { + Some(Ok(event)) => { let event = event.map(move |future| { ListenerUpgradeFuture { future, - upgrade: future::Either::A(Some(self.upgrade.clone())) + upgrade: future::Either::Left(Some(self.upgrade.clone())) } }); - Ok(Async::Ready(Some(event))) + Poll::Ready(Some(Ok(event))) } - None => Ok(Async::Ready(None)) + Some(Err(err)) => { + Poll::Ready(Some(Err(TransportUpgradeError::Transport(err)))) + } + None => Poll::Ready(None) } } } +impl Unpin for ListenerStream { +} + /// The [`Transport::ListenerUpgrade`] future of an [`Upgrade`]d transport. pub struct ListenerUpgradeFuture where - C: AsyncRead + AsyncWrite, + C: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade { future: F, @@ -388,29 +428,44 @@ where impl Future for ListenerUpgradeFuture where - F: Future, - C: AsyncRead + AsyncWrite, + F: TryFuture + Unpin, + C: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade, U::Error: Error { - type Item = (I, D); - type Error = TransportUpgradeError; + type Output = Result<(I, D), TransportUpgradeError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + // We use a `this` variable because the compiler can't mutably borrow multiple times + // across a `Deref`.
+ let this = &mut *self; - fn poll(&mut self) -> Poll { loop { - self.upgrade = match self.upgrade { - future::Either::A(ref mut up) => { - let (i, c) = try_ready!(self.future.poll().map_err(TransportUpgradeError::Transport)); - let u = up.take().expect("ListenerUpgradeFuture is constructed with Either::A(Some)."); - future::Either::B((Some(i), apply_inbound(c, u))) + this.upgrade = match this.upgrade { + future::Either::Left(ref mut up) => { + let (i, c) = match ready!(TryFuture::try_poll(Pin::new(&mut this.future), cx).map_err(TransportUpgradeError::Transport)) { + Ok(v) => v, + Err(err) => return Poll::Ready(Err(err)) + }; + let u = up.take().expect("ListenerUpgradeFuture is constructed with Either::Left(Some)."); + future::Either::Right((Some(i), apply_inbound(c, u))) } - future::Either::B((ref mut i, ref mut up)) => { - let d = try_ready!(up.poll().map_err(TransportUpgradeError::Upgrade)); + future::Either::Right((ref mut i, ref mut up)) => { + let d = match ready!(TryFuture::try_poll(Pin::new(up), cx).map_err(TransportUpgradeError::Upgrade)) { + Ok(v) => v, + Err(err) => return Poll::Ready(Err(err)) + }; let i = i.take().expect("ListenerUpgradeFuture polled after completion."); - return Ok(Async::Ready((i, d))) + return Poll::Ready(Ok((i, d))) } } } } } +impl Unpin for ListenerUpgradeFuture +where + C: AsyncRead + AsyncWrite + Unpin, + U: InboundUpgrade +{ +} diff --git a/core/src/upgrade/apply.rs b/core/src/upgrade/apply.rs index 787ec4c4..c9e1b80e 100644 --- a/core/src/upgrade/apply.rs +++ b/core/src/upgrade/apply.rs @@ -21,34 +21,33 @@ use crate::ConnectedPoint; use crate::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeError}; use crate::upgrade::{ProtocolName, NegotiatedComplete}; -use futures::{future::Either, prelude::*}; +use futures::{future::Either, prelude::*, compat::Compat, compat::Compat01As03, compat::Future01CompatExt}; use log::debug; use multistream_select::{self, DialerSelectFuture, ListenerSelectFuture}; -use std::{iter, mem}; -use tokio_io::{AsyncRead, AsyncWrite}; +use std::{iter, mem, pin::Pin, task::Context, task::Poll}; /// Applies an upgrade to the inbound and outbound direction of a connection or substream. pub fn apply(conn: C, up: U, cp: ConnectedPoint) -> Either, OutboundUpgradeApply> where - C: AsyncRead + AsyncWrite, + C: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade + OutboundUpgrade, { if cp.is_listener() { - Either::A(apply_inbound(conn, up)) + Either::Left(apply_inbound(conn, up)) } else { - Either::B(apply_outbound(conn, up)) + Either::Right(apply_outbound(conn, up)) } } /// Tries to perform an upgrade on an inbound connection or substream. pub fn apply_inbound(conn: C, up: U) -> InboundUpgradeApply where - C: AsyncRead + AsyncWrite, + C: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade, { let iter = up.protocol_info().into_iter().map(NameWrap as fn(_) -> NameWrap<_>); - let future = multistream_select::listener_select_proto(conn, iter); + let future = multistream_select::listener_select_proto(Compat::new(conn), iter).compat(); InboundUpgradeApply { inner: InboundUpgradeApplyState::Init { future, upgrade: up } } @@ -57,11 +56,11 @@ where /// Tries to perform an upgrade on an outbound connection or substream. 
pub fn apply_outbound(conn: C, up: U) -> OutboundUpgradeApply where - C: AsyncRead + AsyncWrite, + C: AsyncRead + AsyncWrite + Unpin, U: OutboundUpgrade { let iter = up.protocol_info().into_iter().map(NameWrap as fn(_) -> NameWrap<_>); - let future = multistream_select::dialer_select_proto(conn, iter); + let future = multistream_select::dialer_select_proto(Compat::new(conn), iter).compat(); OutboundUpgradeApply { inner: OutboundUpgradeApplyState::Init { future, upgrade: up } } @@ -70,7 +69,7 @@ where /// Future returned by `apply_inbound`. Drives the upgrade process. pub struct InboundUpgradeApply where - C: AsyncRead + AsyncWrite, + C: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade { inner: InboundUpgradeApplyState @@ -78,11 +77,11 @@ where enum InboundUpgradeApplyState where - C: AsyncRead + AsyncWrite, + C: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade, { Init { - future: ListenerSelectFuture>, + future: Compat01As03, NameWrap>>, upgrade: U, }, Upgrade { @@ -91,42 +90,49 @@ where Undefined } -impl Future for InboundUpgradeApply +impl Unpin for InboundUpgradeApply where - C: AsyncRead + AsyncWrite, + C: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade, { - type Item = U::Output; - type Error = UpgradeError; +} - fn poll(&mut self) -> Poll { +impl Future for InboundUpgradeApply +where + C: AsyncRead + AsyncWrite + Unpin, + U: InboundUpgrade, + U::Future: Unpin, +{ + type Output = Result>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { loop { match mem::replace(&mut self.inner, InboundUpgradeApplyState::Undefined) { InboundUpgradeApplyState::Init { mut future, upgrade } => { - let (info, io) = match future.poll()? { - Async::Ready(x) => x, - Async::NotReady => { + let (info, io) = match Future::poll(Pin::new(&mut future), cx)? { + Poll::Ready(x) => x, + Poll::Pending => { self.inner = InboundUpgradeApplyState::Init { future, upgrade }; - return Ok(Async::NotReady) + return Poll::Pending } }; self.inner = InboundUpgradeApplyState::Upgrade { - future: upgrade.upgrade_inbound(io, info.0) + future: upgrade.upgrade_inbound(Compat01As03::new(io), info.0) }; } InboundUpgradeApplyState::Upgrade { mut future } => { - match future.poll() { - Ok(Async::NotReady) => { + match Future::poll(Pin::new(&mut future), cx) { + Poll::Pending => { self.inner = InboundUpgradeApplyState::Upgrade { future }; - return Ok(Async::NotReady) + return Poll::Pending } - Ok(Async::Ready(x)) => { + Poll::Ready(Ok(x)) => { debug!("Successfully applied negotiated protocol"); - return Ok(Async::Ready(x)) + return Poll::Ready(Ok(x)) } - Err(e) => { + Poll::Ready(Err(e)) => { debug!("Failed to apply negotiated protocol"); - return Err(UpgradeError::Apply(e)) + return Poll::Ready(Err(UpgradeError::Apply(e))) } } } @@ -140,7 +146,7 @@ where /// Future returned by `apply_outbound`. Drives the upgrade process. 
pub struct OutboundUpgradeApply where - C: AsyncRead + AsyncWrite, + C: AsyncRead + AsyncWrite + Unpin, U: OutboundUpgrade { inner: OutboundUpgradeApplyState @@ -148,15 +154,15 @@ where enum OutboundUpgradeApplyState where - C: AsyncRead + AsyncWrite, + C: AsyncRead + AsyncWrite + Unpin, U: OutboundUpgrade { Init { - future: DialerSelectFuture::IntoIter>>, + future: Compat01As03, NameWrapIter<::IntoIter>>>, upgrade: U }, AwaitNegotiated { - io: NegotiatedComplete, + io: Compat01As03>>, upgrade: U, protocol: U::Info }, @@ -166,58 +172,65 @@ where Undefined } +impl Unpin for OutboundUpgradeApply +where + C: AsyncRead + AsyncWrite + Unpin, + U: OutboundUpgrade, +{ +} + impl Future for OutboundUpgradeApply where - C: AsyncRead + AsyncWrite, - U: OutboundUpgrade + C: AsyncRead + AsyncWrite + Unpin, + U: OutboundUpgrade, + U::Future: Unpin, { - type Item = U::Output; - type Error = UpgradeError; + type Output = Result>; - fn poll(&mut self) -> Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { loop { match mem::replace(&mut self.inner, OutboundUpgradeApplyState::Undefined) { OutboundUpgradeApplyState::Init { mut future, upgrade } => { - let (info, connection) = match future.poll()? { - Async::Ready(x) => x, - Async::NotReady => { + let (info, connection) = match Future::poll(Pin::new(&mut future), cx)? { + Poll::Ready(x) => x, + Poll::Pending => { self.inner = OutboundUpgradeApplyState::Init { future, upgrade }; - return Ok(Async::NotReady) + return Poll::Pending } }; self.inner = OutboundUpgradeApplyState::AwaitNegotiated { - io: connection.complete(), + io: Compat01As03::new(connection.complete()), protocol: info.0, upgrade }; } OutboundUpgradeApplyState::AwaitNegotiated { mut io, protocol, upgrade } => { - let io = match io.poll()? { - Async::NotReady => { + let io = match Future::poll(Pin::new(&mut io), cx)? { + Poll::Pending => { self.inner = OutboundUpgradeApplyState::AwaitNegotiated { io, protocol, upgrade }; - return Ok(Async::NotReady) + return Poll::Pending } - Async::Ready(io) => io + Poll::Ready(io) => io }; self.inner = OutboundUpgradeApplyState::Upgrade { - future: upgrade.upgrade_outbound(io, protocol) + future: upgrade.upgrade_outbound(Compat01As03::new(io), protocol) }; } OutboundUpgradeApplyState::Upgrade { mut future } => { - match future.poll() { - Ok(Async::NotReady) => { + match Future::poll(Pin::new(&mut future), cx) { + Poll::Pending => { self.inner = OutboundUpgradeApplyState::Upgrade { future }; - return Ok(Async::NotReady) + return Poll::Pending } - Ok(Async::Ready(x)) => { + Poll::Ready(Ok(x)) => { debug!("Successfully applied negotiated protocol"); - return Ok(Async::Ready(x)) + return Poll::Ready(Ok(x)) } - Err(e) => { + Poll::Ready(Err(e)) => { debug!("Failed to apply negotiated protocol"); - return Err(UpgradeError::Apply(e)) + return Poll::Ready(Err(UpgradeError::Apply(e))); } } } diff --git a/core/src/upgrade/denied.rs b/core/src/upgrade/denied.rs index 9dec47ee..276d8782 100644 --- a/core/src/upgrade/denied.rs +++ b/core/src/upgrade/denied.rs @@ -18,9 +18,9 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. 
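A note on the pattern used throughout `apply.rs` above (an illustration only, not part of the patch): `multistream-select` still speaks futures 0.1, so the new code crosses the boundary with the `compat`/`io-compat` features enabled in `core/Cargo.toml`. A minimal sketch of that bridging, with the trait bounds recalled from memory rather than taken from this diff:

use futures::compat::{Compat, Compat01As03};
use futures::io::{AsyncRead, AsyncWrite};

// `Compat::new` wraps a futures-0.3 socket so that 0.1-based code (here: the
// multistream-select negotiation) can drive it through the old `tokio_io` traits;
// `Compat01As03::new` does the reverse and is what hands the negotiated stream
// back to `upgrade_inbound`/`upgrade_outbound`, which now expect 0.3 I/O.
fn there_and_back_again<S>(socket: S) -> Compat01As03<Compat<S>>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    Compat01As03::new(Compat::new(socket))
}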
+use crate::Negotiated; use crate::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo}; use futures::future; -use multistream_select::Negotiated; use std::iter; use void::Void; @@ -41,20 +41,19 @@ impl UpgradeInfo for DeniedUpgrade { impl InboundUpgrade for DeniedUpgrade { type Output = Void; type Error = Void; - type Future = future::Empty; + type Future = future::Pending>; fn upgrade_inbound(self, _: Negotiated, _: Self::Info) -> Self::Future { - future::empty() + future::pending() } } impl OutboundUpgrade for DeniedUpgrade { type Output = Void; type Error = Void; - type Future = future::Empty; + type Future = future::Pending>; fn upgrade_outbound(self, _: Negotiated, _: Self::Info) -> Self::Future { - future::empty() + future::pending() } } - diff --git a/core/src/upgrade/either.rs b/core/src/upgrade/either.rs index bf3d86b8..6eb99bb3 100644 --- a/core/src/upgrade/either.rs +++ b/core/src/upgrade/either.rs @@ -19,10 +19,10 @@ // DEALINGS IN THE SOFTWARE. use crate::{ + Negotiated, either::{EitherOutput, EitherError, EitherFuture2, EitherName}, upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo} }; -use multistream_select::Negotiated; /// A type to represent two possible upgrade types (inbound or outbound). #[derive(Debug, Clone)] @@ -50,7 +50,9 @@ where impl InboundUpgrade for EitherUpgrade where A: InboundUpgrade, + >::Future: Unpin, B: InboundUpgrade, + >::Future: Unpin, { type Output = EitherOutput; type Error = EitherError; @@ -68,7 +70,9 @@ where impl OutboundUpgrade for EitherUpgrade where A: OutboundUpgrade, + >::Future: Unpin, B: OutboundUpgrade, + >::Future: Unpin, { type Output = EitherOutput; type Error = EitherError; diff --git a/core/src/upgrade/map.rs b/core/src/upgrade/map.rs index ee17b845..ebbd9a24 100644 --- a/core/src/upgrade/map.rs +++ b/core/src/upgrade/map.rs @@ -18,9 +18,10 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +use crate::Negotiated; use crate::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo}; -use futures::{prelude::*, try_ready}; -use multistream_select::Negotiated; +use futures::prelude::*; +use std::{pin::Pin, task::Context, task::Poll}; /// Wraps around an upgrade and applies a closure to the output. 
#[derive(Debug, Clone)] @@ -47,6 +48,7 @@ where impl InboundUpgrade for MapInboundUpgrade where U: InboundUpgrade, + U::Future: Unpin, F: FnOnce(U::Output) -> T { type Output = T; @@ -63,7 +65,8 @@ where impl OutboundUpgrade for MapInboundUpgrade where - U: OutboundUpgrade + U: OutboundUpgrade, + U::Future: Unpin, { type Output = U::Output; type Error = U::Error; @@ -98,7 +101,8 @@ where impl InboundUpgrade for MapOutboundUpgrade where - U: InboundUpgrade + U: InboundUpgrade, + U::Future: Unpin, { type Output = U::Output; type Error = U::Error; @@ -112,6 +116,7 @@ where impl OutboundUpgrade for MapOutboundUpgrade where U: OutboundUpgrade, + U::Future: Unpin, F: FnOnce(U::Output) -> T { type Output = T; @@ -151,6 +156,7 @@ where impl InboundUpgrade for MapInboundUpgradeErr where U: InboundUpgrade, + U::Future: Unpin, F: FnOnce(U::Error) -> T { type Output = U::Output; @@ -167,7 +173,8 @@ where impl OutboundUpgrade for MapInboundUpgradeErr where - U: OutboundUpgrade + U: OutboundUpgrade, + U::Future: Unpin, { type Output = U::Output; type Error = U::Error; @@ -203,6 +210,7 @@ where impl OutboundUpgrade for MapOutboundUpgradeErr where U: OutboundUpgrade, + U::Future: Unpin, F: FnOnce(U::Error) -> T { type Output = U::Output; @@ -235,18 +243,25 @@ pub struct MapFuture { map: Option, } +impl Unpin for MapFuture { +} + impl Future for MapFuture where - TInnerFut: Future, + TInnerFut: TryFuture + Unpin, TMap: FnOnce(TIn) -> TOut, { - type Item = TOut; - type Error = TInnerFut::Error; + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let item = match TryFuture::try_poll(Pin::new(&mut self.inner), cx) { + Poll::Ready(Ok(v)) => v, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => return Poll::Pending, + }; - fn poll(&mut self) -> Poll { - let item = try_ready!(self.inner.poll()); let map = self.map.take().expect("Future has already finished"); - Ok(Async::Ready(map(item))) + Poll::Ready(Ok(map(item))) } } @@ -255,21 +270,23 @@ pub struct MapErrFuture { fun: Option, } +impl Unpin for MapErrFuture { +} + impl Future for MapErrFuture where - T: Future, + T: TryFuture + Unpin, F: FnOnce(E) -> A, { - type Item = T::Item; - type Error = A; + type Output = Result; - fn poll(&mut self) -> Poll { - match self.fut.poll() { - Ok(Async::NotReady) => Ok(Async::NotReady), - Ok(Async::Ready(x)) => Ok(Async::Ready(x)), - Err(e) => { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + match TryFuture::try_poll(Pin::new(&mut self.fut), cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(x)) => Poll::Ready(Ok(x)), + Poll::Ready(Err(e)) => { let f = self.fun.take().expect("Future has not resolved yet"); - Err(f(e)) + Poll::Ready(Err(f(e))) } } } diff --git a/core/src/upgrade/mod.rs b/core/src/upgrade/mod.rs index 7403655f..14f0d9aa 100644 --- a/core/src/upgrade/mod.rs +++ b/core/src/upgrade/mod.rs @@ -68,7 +68,8 @@ mod transfer; use futures::future::Future; -pub use multistream_select::{Negotiated, NegotiatedComplete, NegotiationError, ProtocolError}; +pub use crate::Negotiated; +pub use multistream_select::{NegotiatedComplete, NegotiationError, ProtocolError}; pub use self::{ apply::{apply, apply_inbound, apply_outbound, InboundUpgradeApply, OutboundUpgradeApply}, denied::DeniedUpgrade, @@ -77,7 +78,7 @@ pub use self::{ map::{MapInboundUpgrade, MapOutboundUpgrade, MapInboundUpgradeErr, MapOutboundUpgradeErr}, optional::OptionalUpgrade, select::SelectUpgrade, - transfer::{write_one, WriteOne, read_one, ReadOne, read_one_then, 
ReadOneThen, ReadOneError, request_response, RequestResponse, read_respond, ReadRespond}, + transfer::{write_one, write_with_len_prefix, write_varint, read_one, ReadOneError, read_varint}, }; /// Types serving as protocol names. @@ -143,7 +144,8 @@ pub trait InboundUpgrade: UpgradeInfo { /// Possible error during the handshake. type Error; /// Future that performs the handshake with the remote. - type Future: Future; + // TODO: remove Unpin + type Future: Future> + Unpin; /// After we have determined that the remote supports one of the protocols we support, this /// method is called to start the handshake. @@ -183,7 +185,8 @@ pub trait OutboundUpgrade: UpgradeInfo { /// Possible error during the handshake. type Error; /// Future that performs the handshake with the remote. - type Future: Future; + // TODO: remove Unpin + type Future: Future> + Unpin; /// After we have determined that the remote supports one of the protocols we support, this /// method is called to start the handshake. diff --git a/core/src/upgrade/optional.rs b/core/src/upgrade/optional.rs index b822d5b9..618f8579 100644 --- a/core/src/upgrade/optional.rs +++ b/core/src/upgrade/optional.rs @@ -19,7 +19,7 @@ // DEALINGS IN THE SOFTWARE. use crate::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo}; -use multistream_select::Negotiated; +use crate::Negotiated; /// Upgrade that can be disabled at runtime. /// diff --git a/core/src/upgrade/select.rs b/core/src/upgrade/select.rs index 61c3ec5e..8adcbabc 100644 --- a/core/src/upgrade/select.rs +++ b/core/src/upgrade/select.rs @@ -19,10 +19,10 @@ // DEALINGS IN THE SOFTWARE. use crate::{ + Negotiated, either::{EitherOutput, EitherError, EitherFuture2, EitherName}, upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo} }; -use multistream_select::Negotiated; /// Upgrade that combines two upgrades into one. Supports all the protocols supported by either /// sub-upgrade. @@ -59,7 +59,9 @@ where impl InboundUpgrade for SelectUpgrade where A: InboundUpgrade, + >::Future: Unpin, B: InboundUpgrade, + >::Future: Unpin, { type Output = EitherOutput; type Error = EitherError; @@ -76,7 +78,9 @@ where impl OutboundUpgrade for SelectUpgrade where A: OutboundUpgrade, + >::Future: Unpin, B: OutboundUpgrade, + >::Future: Unpin, { type Output = EitherOutput; type Error = EitherError; diff --git a/core/src/upgrade/transfer.rs b/core/src/upgrade/transfer.rs index dd5aebcb..57a92f0e 100644 --- a/core/src/upgrade/transfer.rs +++ b/core/src/upgrade/transfer.rs @@ -20,104 +20,93 @@ //! Contains some helper futures for creating upgrades. -use futures::{prelude::*, try_ready}; -use std::{cmp, error, fmt, io::Cursor, mem}; -use tokio_io::{io, AsyncRead, AsyncWrite}; +use futures::prelude::*; +use std::{error, fmt, io}; + +// TODO: these methods could be on an Ext trait to AsyncWrite /// Send a message to the given socket, then shuts down the writing side. /// /// > **Note**: Prepends a variable-length prefix indicate the length of the message. This is /// > compatible with what `read_one` expects. 
-pub fn write_one(socket: TSocket, data: TData) -> WriteOne -where - TSocket: AsyncWrite, - TData: AsRef<[u8]>, +pub async fn write_one(socket: &mut (impl AsyncWrite + Unpin), data: impl AsRef<[u8]>) + -> Result<(), io::Error> { - let len_data = build_int_buffer(data.as_ref().len()); - WriteOne { - inner: WriteOneInner::WriteLen(io::write_all(socket, len_data), data), - } + write_varint(socket, data.as_ref().len()).await?; + socket.write_all(data.as_ref()).await?; + socket.close().await?; + Ok(()) } -/// Builds a buffer that contains the given integer encoded as variable-length. -fn build_int_buffer(num: usize) -> io::Window<[u8; 10]> { - let mut len_data = unsigned_varint::encode::u64_buffer(); - let encoded_len = unsigned_varint::encode::u64(num as u64, &mut len_data).len(); - let mut len_data = io::Window::new(len_data); - len_data.set_end(encoded_len); - len_data -} - -/// Future that makes `write_one` work. -#[derive(Debug)] -pub struct WriteOne> { - inner: WriteOneInner, -} - -#[derive(Debug)] -enum WriteOneInner { - /// We need to write the data length to the socket. - WriteLen(io::WriteAll>, TData), - /// We need to write the actual data to the socket. - Write(io::WriteAll), - /// We need to shut down the socket. - Shutdown(io::Shutdown), - /// A problem happened during the processing. - Poisoned, -} - -impl Future for WriteOne -where - TSocket: AsyncWrite, - TData: AsRef<[u8]>, +/// Send a message to the given socket with a length prefix prepended to it. Also flushes the socket. +/// +/// > **Note**: Prepends a variable-length prefix indicating the length of the message. This is +/// > compatible with what `read_one` expects. +pub async fn write_with_len_prefix(socket: &mut (impl AsyncWrite + Unpin), data: impl AsRef<[u8]>) + -> Result<(), io::Error> { - type Item = (); - type Error = std::io::Error; - - fn poll(&mut self) -> Poll { - Ok(self.inner.poll()?.map(|_socket| ())) - } + write_varint(socket, data.as_ref().len()).await?; + socket.write_all(data.as_ref()).await?; + socket.flush().await?; + Ok(()) } -impl Future for WriteOneInner -where - TSocket: AsyncWrite, - TData: AsRef<[u8]>, +/// Writes a variable-length integer to the `socket`. +/// +/// > **Note**: Does **NOT** flush the socket. +pub async fn write_varint(socket: &mut (impl AsyncWrite + Unpin), len: usize) + -> Result<(), io::Error> { - type Item = TSocket; - type Error = std::io::Error; + let mut len_data = unsigned_varint::encode::usize_buffer(); + let encoded_len = unsigned_varint::encode::usize(len, &mut len_data).len(); + socket.write_all(&len_data[..encoded_len]).await?; + Ok(()) +} - fn poll(&mut self) -> Poll { - loop { - match mem::replace(self, WriteOneInner::Poisoned) { - WriteOneInner::WriteLen(mut inner, data) => match inner.poll()? { - Async::Ready((socket, _)) => { - *self = WriteOneInner::Write(io::write_all(socket, data)); - } - Async::NotReady => { - *self = WriteOneInner::WriteLen(inner, data); - } - }, - WriteOneInner::Write(mut inner) => match inner.poll()? { - Async::Ready((socket, _)) => { - *self = WriteOneInner::Shutdown(tokio_io::io::shutdown(socket)); - } - Async::NotReady => { - *self = WriteOneInner::Write(inner); - } - }, - WriteOneInner::Shutdown(ref mut inner) => { - let socket = try_ready!(inner.poll()); - return Ok(Async::Ready(socket)); +/// Reads a variable-length integer from the `socket`. +/// +/// As a special exception, if the `socket` is empty and EOFs right at the beginning, then we +/// return `Ok(0)`.
+/// +/// > **Note**: This function reads bytes one by one from the `socket`. It is therefore encouraged +/// > to use some sort of buffering mechanism. +pub async fn read_varint(socket: &mut (impl AsyncRead + Unpin)) -> Result { + let mut buffer = unsigned_varint::encode::usize_buffer(); + let mut buffer_len = 0; + + loop { + match socket.read(&mut buffer[buffer_len..buffer_len+1]).await? { + 0 => { + // Reaching EOF before finishing to read the length is an error, unless the EOF is + // at the very beginning of the substream, in which case we assume that the data is + // empty. + if buffer_len == 0 { + return Ok(0); + } else { + return Err(io::ErrorKind::UnexpectedEof.into()); } - WriteOneInner::Poisoned => panic!(), } + n => debug_assert_eq!(n, 1), + } + + buffer_len += 1; + + match unsigned_varint::decode::usize(&buffer[..buffer_len]) { + Ok((len, _)) => return Ok(len), + Err(unsigned_varint::decode::Error::Overflow) => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "overflow in variable-length integer" + )); + } + // TODO: why do we have a `__Nonexhaustive` variant in the error? I don't know how to process it + // Err(unsigned_varint::decode::Error::Insufficient) => {} + Err(_) => {} } } } -/// Reads a message from the given socket. Only one message is processed and the socket is dropped, -/// because we assume that the socket will not send anything more. +/// Reads a length-prefixed message from the given socket. /// /// The `max_size` parameter is the maximum size in bytes of the message that we accept. This is /// necessary in order to avoid DoS attacks where the remote sends us a message of several @@ -125,137 +114,20 @@ where /// /// > **Note**: Assumes that a variable-length prefix indicates the length of the message. This is /// > compatible with what `write_one` does. -pub fn read_one( - socket: TSocket, - max_size: usize, -) -> ReadOne +pub async fn read_one(socket: &mut (impl AsyncRead + Unpin), max_size: usize) + -> Result, ReadOneError> { - ReadOne { - inner: ReadOneInner::ReadLen { - socket, - len_buf: Cursor::new([0; 10]), - max_size, - }, + let len = read_varint(socket).await?; + if len > max_size { + return Err(ReadOneError::TooLarge { + requested: len, + max: max_size, + }); } -} -/// Future that makes `read_one` work. -#[derive(Debug)] -pub struct ReadOne { - inner: ReadOneInner, -} - -#[derive(Debug)] -enum ReadOneInner { - // We need to read the data length from the socket. - ReadLen { - socket: TSocket, - /// A small buffer where we will right the variable-length integer representing the - /// length of the actual packet. - len_buf: Cursor<[u8; 10]>, - max_size: usize, - }, - // We need to read the actual data from the socket. - ReadRest(io::ReadExact>>), - /// A problem happened during the processing. - Poisoned, -} - -impl Future for ReadOne -where - TSocket: AsyncRead, -{ - type Item = Vec; - type Error = ReadOneError; - - fn poll(&mut self) -> Poll { - Ok(self.inner.poll()?.map(|(_, out)| out)) - } -} - -impl Future for ReadOneInner -where - TSocket: AsyncRead, -{ - type Item = (TSocket, Vec); - type Error = ReadOneError; - - fn poll(&mut self) -> Poll { - loop { - match mem::replace(self, ReadOneInner::Poisoned) { - ReadOneInner::ReadLen { - mut socket, - mut len_buf, - max_size, - } => { - match socket.read_buf(&mut len_buf)? { - Async::Ready(num_read) => { - // Reaching EOF before finishing to read the length is an error, unless - // the EOF is at the very beginning of the substream, in which case we - // assume that the data is empty. 
- if num_read == 0 { - if len_buf.position() == 0 { - return Ok(Async::Ready((socket, Vec::new()))); - } else { - return Err(ReadOneError::Io( - std::io::ErrorKind::UnexpectedEof.into(), - )); - } - } - - let len_buf_with_data = - &len_buf.get_ref()[..len_buf.position() as usize]; - if let Ok((len, data_start)) = - unsigned_varint::decode::usize(len_buf_with_data) - { - if len >= max_size { - return Err(ReadOneError::TooLarge { - requested: len, - max: max_size, - }); - } - - // Create `data_buf` containing the start of the data that was - // already in `len_buf`. - let n = cmp::min(data_start.len(), len); - let mut data_buf = vec![0; len]; - data_buf[.. n].copy_from_slice(&data_start[.. n]); - let mut data_buf = io::Window::new(data_buf); - data_buf.set_start(data_start.len()); - *self = ReadOneInner::ReadRest(io::read_exact(socket, data_buf)); - } else { - *self = ReadOneInner::ReadLen { - socket, - len_buf, - max_size, - }; - } - } - Async::NotReady => { - *self = ReadOneInner::ReadLen { - socket, - len_buf, - max_size, - }; - return Ok(Async::NotReady); - } - } - } - ReadOneInner::ReadRest(mut inner) => { - match inner.poll()? { - Async::Ready((socket, data)) => { - return Ok(Async::Ready((socket, data.into_inner()))); - } - Async::NotReady => { - *self = ReadOneInner::ReadRest(inner); - return Ok(Async::NotReady); - } - } - } - ReadOneInner::Poisoned => panic!(), - } - } - } + let mut buf = vec![0; len]; + socket.read_exact(&mut buf).await?; + Ok(buf) } /// Error while reading one message. @@ -296,194 +168,10 @@ impl error::Error for ReadOneError { } } -/// Similar to `read_one`, but applies a transformation on the output buffer. -/// -/// > **Note**: The `param` parameter is an arbitrary value that will be passed back to `then`. -/// > This parameter is normally not necessary, as we could just pass a closure that has -/// > ownership of any data we want. In practice, though, this would make the -/// > `ReadRespond` type impossible to express as a concrete type. Once the `impl Trait` -/// > syntax is allowed within traits, we can remove this parameter. -pub fn read_one_then( - socket: TSocket, - max_size: usize, - param: TParam, - then: TThen, -) -> ReadOneThen -where - TSocket: AsyncRead, - TThen: FnOnce(Vec, TParam) -> Result, - TErr: From, -{ - ReadOneThen { - inner: read_one(socket, max_size), - then: Some((param, then)), - } -} - -/// Future that makes `read_one_then` work. -#[derive(Debug)] -pub struct ReadOneThen { - inner: ReadOne, - then: Option<(TParam, TThen)>, -} - -impl Future for ReadOneThen -where - TSocket: AsyncRead, - TThen: FnOnce(Vec, TParam) -> Result, - TErr: From, -{ - type Item = TOut; - type Error = TErr; - - fn poll(&mut self) -> Poll { - match self.inner.poll()? { - Async::Ready(buffer) => { - let (param, then) = self.then.take() - .expect("Future was polled after it was finished"); - Ok(Async::Ready(then(buffer, param)?)) - }, - Async::NotReady => Ok(Async::NotReady), - } - } -} - -/// Similar to `read_one`, but applies a transformation on the output buffer. -/// -/// > **Note**: The `param` parameter is an arbitrary value that will be passed back to `then`. -/// > This parameter is normally not necessary, as we could just pass a closure that has -/// > ownership of any data we want. In practice, though, this would make the -/// > `ReadRespond` type impossible to express as a concrete type. Once the `impl Trait` -/// > syntax is allowed within traits, we can remove this parameter. 
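Before the rest of the legacy combinators (`read_respond`, `request_response`) are removed below, a quick round-trip illustration of the new async helpers introduced above. This is a sketch only, not part of the patch; it leans on the `write_one`/`read_one` re-exports added to `upgrade/mod.rs` and on the futures-0.3 `AsyncRead`/`AsyncWrite` impls for `std::io::Cursor`.

use libp2p_core::upgrade::{read_one, write_one};
use std::io::Cursor;

// Write a length-prefixed message into an in-memory buffer, then read it back.
fn round_trip_sketch() {
    futures::executor::block_on(async {
        let payload = b"hello world".to_vec();

        // `Cursor<Vec<u8>>` implements the futures-0.3 `AsyncWrite`, which is all the
        // new `async fn write_one` needs.
        let mut write_end = Cursor::new(Vec::new());
        write_one(&mut write_end, &payload).await.unwrap();

        // Read the message back with the matching helper; 1024 is the accepted maximum size.
        let mut read_end = Cursor::new(write_end.into_inner());
        let echoed = read_one(&mut read_end, 1024).await.unwrap();
        assert_eq!(echoed, payload);
    });
}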
-pub fn read_respond( - socket: TSocket, - max_size: usize, - param: TParam, - then: TThen, -) -> ReadRespond -where - TSocket: AsyncRead, - TThen: FnOnce(TSocket, Vec, TParam) -> Result, - TErr: From, -{ - ReadRespond { - inner: read_one(socket, max_size).inner, - then: Some((then, param)), - } -} - -/// Future that makes `read_respond` work. -#[derive(Debug)] -pub struct ReadRespond { - inner: ReadOneInner, - then: Option<(TThen, TParam)>, -} - -impl Future for ReadRespond -where - TSocket: AsyncRead, - TThen: FnOnce(TSocket, Vec, TParam) -> Result, - TErr: From, -{ - type Item = TOut; - type Error = TErr; - - fn poll(&mut self) -> Poll { - match self.inner.poll()? { - Async::Ready((socket, buffer)) => { - let (then, param) = self.then.take().expect("Future was polled after it was finished"); - Ok(Async::Ready(then(socket, buffer, param)?)) - }, - Async::NotReady => Ok(Async::NotReady), - } - } -} - -/// Send a message to the given socket, then shuts down the writing side, then reads an answer. -/// -/// This combines `write_one` followed with `read_one_then`. -/// -/// > **Note**: The `param` parameter is an arbitrary value that will be passed back to `then`. -/// > This parameter is normally not necessary, as we could just pass a closure that has -/// > ownership of any data we want. In practice, though, this would make the -/// > `ReadRespond` type impossible to express as a concrete type. Once the `impl Trait` -/// > syntax is allowed within traits, we can remove this parameter. -pub fn request_response( - socket: TSocket, - data: TData, - max_size: usize, - param: TParam, - then: TThen, -) -> RequestResponse -where - TSocket: AsyncRead + AsyncWrite, - TData: AsRef<[u8]>, - TThen: FnOnce(Vec, TParam) -> Result, -{ - RequestResponse { - inner: RequestResponseInner::Write(write_one(socket, data).inner, max_size, param, then), - } -} - -/// Future that makes `request_response` work. -#[derive(Debug)] -pub struct RequestResponse> { - inner: RequestResponseInner, -} - -#[derive(Debug)] -enum RequestResponseInner { - // We need to write data to the socket. - Write(WriteOneInner, usize, TParam, TThen), - // We need to read the message. - Read(ReadOneThen), - // An error happened during the processing. - Poisoned, -} - -impl Future for RequestResponse -where - TSocket: AsyncRead + AsyncWrite, - TData: AsRef<[u8]>, - TThen: FnOnce(Vec, TParam) -> Result, - TErr: From, -{ - type Item = TOut; - type Error = TErr; - - fn poll(&mut self) -> Poll { - loop { - match mem::replace(&mut self.inner, RequestResponseInner::Poisoned) { - RequestResponseInner::Write(mut inner, max_size, param, then) => { - match inner.poll().map_err(ReadOneError::Io)? { - Async::Ready(socket) => { - self.inner = - RequestResponseInner::Read(read_one_then(socket, max_size, param, then)); - } - Async::NotReady => { - self.inner = RequestResponseInner::Write(inner, max_size, param, then); - return Ok(Async::NotReady); - } - } - } - RequestResponseInner::Read(mut inner) => match inner.poll()? 
{ - Async::Ready(packet) => return Ok(Async::Ready(packet)), - Async::NotReady => { - self.inner = RequestResponseInner::Read(inner); - return Ok(Async::NotReady); - } - }, - RequestResponseInner::Poisoned => panic!(), - }; - } - } -} - #[cfg(test)] mod tests { use super::*; use std::io::{self, Cursor}; - use tokio::runtime::current_thread::Runtime; #[test] fn write_one_works() { @@ -492,14 +180,17 @@ mod tests { .collect::>(); let mut out = vec![0; 10_000]; - let future = write_one(Cursor::new(&mut out[..]), data.clone()); - Runtime::new().unwrap().block_on(future).unwrap(); + futures::executor::block_on( + write_one(&mut Cursor::new(&mut out[..]), data.clone()) + ).unwrap(); let (out_len, out_data) = unsigned_varint::decode::usize(&out).unwrap(); assert_eq!(out_len, data.len()); assert_eq!(&out_data[..out_len], &data[..]); } + // TODO: rewrite these tests +/* #[test] fn read_one_works() { let original_data = (0..rand::random::() % 10_000) @@ -517,7 +208,7 @@ mod tests { Ok(()) }); - Runtime::new().unwrap().block_on(future).unwrap(); + futures::executor::block_on(future).unwrap(); } #[test] @@ -527,7 +218,7 @@ mod tests { Ok(()) }); - Runtime::new().unwrap().block_on(future).unwrap(); + futures::executor::block_on(future).unwrap(); } #[test] @@ -542,7 +233,7 @@ mod tests { Ok(()) }); - match Runtime::new().unwrap().block_on(future) { + match futures::executor::block_on(future) { Err(ReadOneError::TooLarge { .. }) => (), _ => panic!(), } @@ -555,7 +246,7 @@ mod tests { Ok(()) }); - Runtime::new().unwrap().block_on(future).unwrap(); + futures::executor::block_on(future).unwrap(); } #[test] @@ -564,9 +255,9 @@ mod tests { unreachable!() }); - match Runtime::new().unwrap().block_on(future) { + match futures::executor::block_on(future) { Err(ReadOneError::Io(ref err)) if err.kind() == io::ErrorKind::UnexpectedEof => (), _ => panic!() } - } + }*/ } diff --git a/core/tests/network_dial_error.rs b/core/tests/network_dial_error.rs index cc9c3dfa..4cd0b39b 100644 --- a/core/tests/network_dial_error.rs +++ b/core/tests/network_dial_error.rs @@ -20,7 +20,7 @@ mod util; -use futures::{future, prelude::*}; +use futures::prelude::*; use libp2p_core::identity; use libp2p_core::multiaddr::multiaddr; use libp2p_core::nodes::network::{Network, NetworkEvent, NetworkReachError, PeerState, UnknownPeerDialErr, IncomingError}; @@ -47,7 +47,7 @@ impl Default for TestHandler { impl ProtocolsHandler for TestHandler where - TSubstream: tokio_io::AsyncRead + tokio_io::AsyncWrite + TSubstream: futures::PollRead + futures::PollWrite { type InEvent = (); // TODO: cannot be Void (https://github.com/servo/rust-smallvec/issues/139) type OutEvent = (); // TODO: cannot be Void (https://github.com/servo/rust-smallvec/issues/139) @@ -82,8 +82,8 @@ where fn connection_keep_alive(&self) -> KeepAlive { KeepAlive::No } - fn poll(&mut self) -> Poll, Self::Error> { - Ok(Async::NotReady) + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll, Self::Error> { + Poll::Pending } } @@ -114,7 +114,7 @@ fn deny_incoming_connec() { swarm1.listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()).unwrap(); let address = - if let Async::Ready(NetworkEvent::NewListenerAddress { listen_addr, .. }) = swarm1.poll() { + if let Poll::Ready(NetworkEvent::NewListenerAddress { listen_addr, .. 
}) = swarm1.poll() { listen_addr } else { panic!("Was expecting the listen address to be reported") @@ -125,15 +125,15 @@ fn deny_incoming_connec() { .into_not_connected().unwrap() .connect(address.clone(), TestHandler::default().into_node_handler_builder()); - let future = future::poll_fn(|| -> Poll<(), io::Error> { + let future = future::poll_fn(|| -> Poll> { match swarm1.poll() { - Async::Ready(NetworkEvent::IncomingConnection(inc)) => drop(inc), - Async::Ready(_) => unreachable!(), - Async::NotReady => (), + Poll::Ready(NetworkEvent::IncomingConnection(inc)) => drop(inc), + Poll::Ready(_) => unreachable!(), + Poll::Pending => (), } match swarm2.poll() { - Async::Ready(NetworkEvent::DialError { + Poll::Ready(NetworkEvent::DialError { new_state: PeerState::NotConnected, peer_id, multiaddr, @@ -141,13 +141,13 @@ fn deny_incoming_connec() { }) => { assert_eq!(peer_id, *swarm1.local_peer_id()); assert_eq!(multiaddr, address); - return Ok(Async::Ready(())); + return Poll::Ready(Ok(())); }, - Async::Ready(_) => unreachable!(), - Async::NotReady => (), + Poll::Ready(_) => unreachable!(), + Poll::Pending => (), } - Ok(Async::NotReady) + Poll::Pending }); tokio::runtime::current_thread::Runtime::new().unwrap().block_on(future).unwrap(); @@ -185,7 +185,7 @@ fn dial_self() { let (address, mut swarm) = future::lazy(move || { - if let Async::Ready(NetworkEvent::NewListenerAddress { listen_addr, .. }) = swarm.poll() { + if let Poll::Ready(NetworkEvent::NewListenerAddress { listen_addr, .. }) = swarm.poll() { Ok::<_, void::Void>((listen_addr, swarm)) } else { panic!("Was expecting the listen address to be reported") @@ -198,10 +198,10 @@ fn dial_self() { let mut got_dial_err = false; let mut got_inc_err = false; - let future = future::poll_fn(|| -> Poll<(), io::Error> { + let future = future::poll_fn(|| -> Poll> { loop { match swarm.poll() { - Async::Ready(NetworkEvent::UnknownPeerDialError { + Poll::Ready(NetworkEvent::UnknownPeerDialError { multiaddr, error: UnknownPeerDialErr::FoundLocalPeerId, handler: _ @@ -210,10 +210,10 @@ fn dial_self() { assert!(!got_dial_err); got_dial_err = true; if got_inc_err { - return Ok(Async::Ready(())); + return Ok(Poll::Ready(())); } }, - Async::Ready(NetworkEvent::IncomingConnectionError { + Poll::Ready(NetworkEvent::IncomingConnectionError { local_addr, send_back_addr: _, error: IncomingError::FoundLocalPeerId @@ -222,17 +222,17 @@ fn dial_self() { assert!(!got_inc_err); got_inc_err = true; if got_dial_err { - return Ok(Async::Ready(())); + return Ok(Poll::Ready(())); } }, - Async::Ready(NetworkEvent::IncomingConnection(inc)) => { + Poll::Ready(NetworkEvent::IncomingConnection(inc)) => { assert_eq!(*inc.local_addr(), address); inc.accept(TestHandler::default().into_node_handler_builder()); }, - Async::Ready(ev) => { + Poll::Ready(ev) => { panic!("Unexpected event: {:?}", ev) } - Async::NotReady => break Ok(Async::NotReady), + Poll::Pending => break Poll::Pending, } } }); @@ -288,10 +288,10 @@ fn multiple_addresses_err() { .connect_iter(addresses.clone(), TestHandler::default().into_node_handler_builder()) .unwrap(); - let future = future::poll_fn(|| -> Poll<(), io::Error> { + let future = future::poll_fn(|| -> Poll> { loop { match swarm.poll() { - Async::Ready(NetworkEvent::DialError { + Poll::Ready(NetworkEvent::DialError { new_state, peer_id, multiaddr, @@ -302,7 +302,7 @@ fn multiple_addresses_err() { assert_eq!(multiaddr, expected); if addresses.is_empty() { assert_eq!(new_state, PeerState::NotConnected); - return Ok(Async::Ready(())); + return 
Ok(Poll::Ready(())); } else { match new_state { PeerState::Dialing { num_pending_addresses } => { @@ -312,8 +312,8 @@ fn multiple_addresses_err() { } } }, - Async::Ready(_) => unreachable!(), - Async::NotReady => break Ok(Async::NotReady), + Poll::Ready(_) => unreachable!(), + Poll::Pending => break Poll::Pending, } } }); diff --git a/core/tests/network_simult.rs b/core/tests/network_simult.rs index 958631b5..785ae1a7 100644 --- a/core/tests/network_simult.rs +++ b/core/tests/network_simult.rs @@ -20,7 +20,7 @@ mod util; -use futures::{future, prelude::*}; +use futures::prelude::*; use libp2p_core::{identity, upgrade, Transport}; use libp2p_core::nodes::{Network, NetworkEvent, Peer}; use libp2p_core::nodes::network::IncomingError; @@ -45,7 +45,7 @@ impl Default for TestHandler { impl ProtocolsHandler for TestHandler where - TSubstream: tokio_io::AsyncRead + tokio_io::AsyncWrite + TSubstream: futures::PollRead + futures::PollWrite { type InEvent = (); // TODO: cannot be Void (https://github.com/servo/rust-smallvec/issues/139) type OutEvent = (); // TODO: cannot be Void (https://github.com/servo/rust-smallvec/issues/139) @@ -80,8 +80,8 @@ where fn connection_keep_alive(&self) -> KeepAlive { KeepAlive::Yes } - fn poll(&mut self) -> Poll, Self::Error> { - Ok(Async::NotReady) + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll, Self::Error> { + Poll::Pending } } @@ -142,14 +142,14 @@ fn raw_swarm_simultaneous_connect() { let (swarm1_listen_addr, swarm2_listen_addr, mut swarm1, mut swarm2) = future::lazy(move || { let swarm1_listen_addr = - if let Async::Ready(NetworkEvent::NewListenerAddress { listen_addr, .. }) = swarm1.poll() { + if let Poll::Ready(NetworkEvent::NewListenerAddress { listen_addr, .. }) = swarm1.poll() { listen_addr } else { panic!("Was expecting the listen address to be reported") }; let swarm2_listen_addr = - if let Async::Ready(NetworkEvent::NewListenerAddress { listen_addr, .. }) = swarm2.poll() { + if let Poll::Ready(NetworkEvent::NewListenerAddress { listen_addr, .. }) = swarm2.poll() { listen_addr } else { panic!("Was expecting the listen address to be reported") @@ -179,7 +179,7 @@ fn raw_swarm_simultaneous_connect() { if swarm1_step == 0 { match swarm1_dial_start.poll().unwrap() { - Async::Ready(_) => { + Poll::Ready(_) => { let handler = TestHandler::default().into_node_handler_builder(); swarm1.peer(swarm2.local_peer_id().clone()) .into_not_connected() @@ -187,13 +187,13 @@ fn raw_swarm_simultaneous_connect() { .connect(swarm2_listen_addr.clone(), handler); swarm1_step = 1; }, - Async::NotReady => swarm1_not_ready = true, + Poll::Pending => swarm1_not_ready = true, } } if swarm2_step == 0 { match swarm2_dial_start.poll().unwrap() { - Async::Ready(_) => { + Poll::Ready(_) => { let handler = TestHandler::default().into_node_handler_builder(); swarm2.peer(swarm1.local_peer_id().clone()) .into_not_connected() @@ -201,79 +201,79 @@ fn raw_swarm_simultaneous_connect() { .connect(swarm1_listen_addr.clone(), handler); swarm2_step = 1; }, - Async::NotReady => swarm2_not_ready = true, + Poll::Pending => swarm2_not_ready = true, } } if rand::random::() < 0.1 { match swarm1.poll() { - Async::Ready(NetworkEvent::IncomingConnectionError { + Poll::Ready(NetworkEvent::IncomingConnectionError { error: IncomingError::DeniedLowerPriority, .. }) => { assert_eq!(swarm1_step, 2); swarm1_step = 3; }, - Async::Ready(NetworkEvent::Connected { conn_info, .. }) => { + Poll::Ready(NetworkEvent::Connected { conn_info, .. 
}) => { assert_eq!(conn_info, *swarm2.local_peer_id()); if swarm1_step == 0 { // The connection was established before // swarm1 started dialing; discard the test run. - return Ok(Async::Ready(false)) + return Ok(Poll::Ready(false)) } assert_eq!(swarm1_step, 1); swarm1_step = 2; }, - Async::Ready(NetworkEvent::Replaced { new_info, .. }) => { + Poll::Ready(NetworkEvent::Replaced { new_info, .. }) => { assert_eq!(new_info, *swarm2.local_peer_id()); assert_eq!(swarm1_step, 2); swarm1_step = 3; }, - Async::Ready(NetworkEvent::IncomingConnection(inc)) => { + Poll::Ready(NetworkEvent::IncomingConnection(inc)) => { inc.accept(TestHandler::default().into_node_handler_builder()); }, - Async::Ready(ev) => panic!("swarm1: unexpected event: {:?}", ev), - Async::NotReady => swarm1_not_ready = true, + Poll::Ready(ev) => panic!("swarm1: unexpected event: {:?}", ev), + Poll::Pending => swarm1_not_ready = true, } } if rand::random::() < 0.1 { match swarm2.poll() { - Async::Ready(NetworkEvent::IncomingConnectionError { + Poll::Ready(NetworkEvent::IncomingConnectionError { error: IncomingError::DeniedLowerPriority, .. }) => { assert_eq!(swarm2_step, 2); swarm2_step = 3; }, - Async::Ready(NetworkEvent::Connected { conn_info, .. }) => { + Poll::Ready(NetworkEvent::Connected { conn_info, .. }) => { assert_eq!(conn_info, *swarm1.local_peer_id()); if swarm2_step == 0 { // The connection was established before // swarm2 started dialing; discard the test run. - return Ok(Async::Ready(false)) + return Ok(Poll::Ready(false)) } assert_eq!(swarm2_step, 1); swarm2_step = 2; }, - Async::Ready(NetworkEvent::Replaced { new_info, .. }) => { + Poll::Ready(NetworkEvent::Replaced { new_info, .. }) => { assert_eq!(new_info, *swarm1.local_peer_id()); assert_eq!(swarm2_step, 2); swarm2_step = 3; }, - Async::Ready(NetworkEvent::IncomingConnection(inc)) => { + Poll::Ready(NetworkEvent::IncomingConnection(inc)) => { inc.accept(TestHandler::default().into_node_handler_builder()); }, - Async::Ready(ev) => panic!("swarm2: unexpected event: {:?}", ev), - Async::NotReady => swarm2_not_ready = true, + Poll::Ready(ev) => panic!("swarm2: unexpected event: {:?}", ev), + Poll::Pending => swarm2_not_ready = true, } } // TODO: make sure that >= 5 is correct if swarm1_step + swarm2_step >= 5 { - return Ok(Async::Ready(true)); + return Ok(Poll::Ready(true)); } if swarm1_not_ready && swarm2_not_ready { - return Ok(Async::NotReady); + return Poll::Pending; } } }); diff --git a/core/tests/transport_upgrade.rs b/core/tests/transport_upgrade.rs index 61b96f35..96515da4 100644 --- a/core/tests/transport_upgrade.rs +++ b/core/tests/transport_upgrade.rs @@ -20,8 +20,7 @@ mod util; -use futures::future::Future; -use futures::stream::Stream; +use futures::prelude::*; use libp2p_core::identity; use libp2p_core::transport::{Transport, MemoryTransport, ListenerEvent}; use libp2p_core::upgrade::{UpgradeInfo, Negotiated, InboundUpgrade, OutboundUpgrade}; @@ -30,7 +29,6 @@ use libp2p_secio::SecioConfig; use multiaddr::Multiaddr; use rand::random; use std::io; -use tokio_io::{io as nio, AsyncWrite, AsyncRead}; #[derive(Clone)] struct HelloUpgrade {} diff --git a/core/tests/util.rs b/core/tests/util.rs index b4344282..69b1f936 100644 --- a/core/tests/util.rs +++ b/core/tests/util.rs @@ -3,6 +3,7 @@ use futures::prelude::*; use libp2p_core::muxing::StreamMuxer; +use std::{pin::Pin, task::Context, task::Poll}; pub struct CloseMuxer { state: CloseMuxerState, @@ -26,18 +27,17 @@ where M: StreamMuxer, M::Error: From { - type Item = M; - type Error = M::Error; + type 
Output = Result; - fn poll(&mut self) -> Poll { + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { loop { match std::mem::replace(&mut self.state, CloseMuxerState::Done) { CloseMuxerState::Close(muxer) => { if muxer.close()?.is_not_ready() { self.state = CloseMuxerState::Close(muxer); - return Ok(Async::NotReady) + return Poll::Pending } - return Ok(Async::Ready(muxer)) + return Poll::Ready(Ok(muxer)) } CloseMuxerState::Done => panic!() } diff --git a/misc/core-derive/src/lib.rs b/misc/core-derive/src/lib.rs index e6b84a58..da45329e 100644 --- a/misc/core-derive/src/lib.rs +++ b/misc/core-derive/src/lib.rs @@ -381,7 +381,7 @@ fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream { // If we find a `#[behaviour(poll_method = "poll")]` attribute on the struct, we call // `self.poll()` at the end of the polling. let poll_method = { - let mut poll_method = quote!{Async::NotReady}; + let mut poll_method = quote!{Poll::Pending}; for meta_items in ast.attrs.iter().filter_map(get_meta_items) { for meta_item in meta_items { match meta_item { @@ -419,25 +419,25 @@ fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream { Some(quote!{ loop { match #field_name.poll(poll_params) { - Async::Ready(#network_behaviour_action::GenerateEvent(event)) => { + Poll::Ready(#network_behaviour_action::GenerateEvent(event)) => { #net_behv_event_proc::inject_event(self, event) } - Async::Ready(#network_behaviour_action::DialAddress { address }) => { - return Async::Ready(#network_behaviour_action::DialAddress { address }); + Poll::Ready(#network_behaviour_action::DialAddress { address }) => { + return Poll::Ready(#network_behaviour_action::DialAddress { address }); } - Async::Ready(#network_behaviour_action::DialPeer { peer_id }) => { - return Async::Ready(#network_behaviour_action::DialPeer { peer_id }); + Poll::Ready(#network_behaviour_action::DialPeer { peer_id }) => { + return Poll::Ready(#network_behaviour_action::DialPeer { peer_id }); } - Async::Ready(#network_behaviour_action::SendEvent { peer_id, event }) => { - return Async::Ready(#network_behaviour_action::SendEvent { + Poll::Ready(#network_behaviour_action::SendEvent { peer_id, event }) => { + return Poll::Ready(#network_behaviour_action::SendEvent { peer_id, event: #wrapped_event, }); } - Async::Ready(#network_behaviour_action::ReportObservedAddr { address }) => { - return Async::Ready(#network_behaviour_action::ReportObservedAddr { address }); + Poll::Ready(#network_behaviour_action::ReportObservedAddr { address }) => { + return Poll::Ready(#network_behaviour_action::ReportObservedAddr { address }); } - Async::NotReady => break, + Poll::Pending => break, } } }) @@ -512,10 +512,10 @@ fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream { } } - fn poll(&mut self, poll_params: &mut impl #poll_parameters) -> ::libp2p::futures::Async<#network_behaviour_action<<::Handler as #protocols_handler>::InEvent, Self::OutEvent>> { + fn poll(&mut self, cx: &mut std::task::Context, poll_params: &mut impl #poll_parameters) -> std::task::Poll<#network_behaviour_action<<::Handler as #protocols_handler>::InEvent, Self::OutEvent>> { use libp2p::futures::prelude::*; #(#poll_stmts)* - let f: ::libp2p::futures::Async<#network_behaviour_action<<::Handler as #protocols_handler>::InEvent, Self::OutEvent>> = #poll_method; + let f: std::task::Poll<#network_behaviour_action<<::Handler as #protocols_handler>::InEvent, Self::OutEvent>> = #poll_method; f } } diff --git a/misc/mdns/Cargo.toml b/misc/mdns/Cargo.toml 
index 7fb84788..e532e865 100644 --- a/misc/mdns/Cargo.toml +++ b/misc/mdns/Cargo.toml @@ -10,9 +10,10 @@ keywords = ["peer-to-peer", "libp2p", "networking"] categories = ["network-programming", "asynchronous"] [dependencies] +async-std = "0.99" data-encoding = "2.0" dns-parser = "0.8" -futures = "0.1" +futures-preview = "0.3.0-alpha.17" libp2p-core = { version = "0.12.0", path = "../../core" } libp2p-swarm = { version = "0.2.0", path = "../../swarm" } log = "0.4" @@ -20,11 +21,5 @@ multiaddr = { package = "parity-multiaddr", version = "0.5.0", path = "../multia net2 = "0.2" rand = "0.6" smallvec = "0.6" -tokio-io = "0.1" -tokio-reactor = "0.1" -wasm-timer = "0.1" -tokio-udp = "0.1" +wasm-timer = "0.2" void = "1.0" - -[dev-dependencies] -tokio = "0.1" diff --git a/misc/mdns/src/behaviour.rs b/misc/mdns/src/behaviour.rs index 7d933211..cbdd2503 100644 --- a/misc/mdns/src/behaviour.rs +++ b/misc/mdns/src/behaviour.rs @@ -30,8 +30,7 @@ use libp2p_swarm::{ }; use log::warn; use smallvec::SmallVec; -use std::{cmp, fmt, io, iter, marker::PhantomData, time::Duration}; -use tokio_io::{AsyncRead, AsyncWrite}; +use std::{cmp, fmt, io, iter, marker::PhantomData, pin::Pin, time::Duration, task::Context, task::Poll}; use wasm_timer::{Delay, Instant}; /// A `NetworkBehaviour` for mDNS. Automatically discovers peers on the local network and adds @@ -57,9 +56,9 @@ pub struct Mdns { impl Mdns { /// Builds a new `Mdns` behaviour. - pub fn new() -> io::Result> { + pub async fn new() -> io::Result> { Ok(Mdns { - service: MdnsService::new()?, + service: MdnsService::new().await?, discovered_nodes: SmallVec::new(), closest_expiration: None, marker: PhantomData, @@ -145,7 +144,7 @@ impl fmt::Debug for ExpiredAddrsIter { impl NetworkBehaviour for Mdns where - TSubstream: AsyncRead + AsyncWrite, + TSubstream: AsyncRead + AsyncWrite + Unpin, { type ProtocolsHandler = DummyProtocolsHandler; type OutEvent = MdnsEvent; @@ -177,8 +176,9 @@ where fn poll( &mut self, + cx: &mut Context, params: &mut impl PollParameters, - ) -> Async< + ) -> Poll< NetworkBehaviourAction< ::InEvent, Self::OutEvent, @@ -186,8 +186,8 @@ where > { // Remove expired peers. if let Some(ref mut closest_expiration) = self.closest_expiration { - match closest_expiration.poll() { - Ok(Async::Ready(())) => { + match Future::poll(Pin::new(closest_expiration), cx) { + Poll::Ready(Ok(())) => { let now = Instant::now(); let mut expired = SmallVec::<[(PeerId, Multiaddr); 4]>::new(); while let Some(pos) = self.discovered_nodes.iter().position(|(_, _, exp)| *exp < now) { @@ -200,19 +200,19 @@ where inner: expired.into_iter(), }); - return Async::Ready(NetworkBehaviourAction::GenerateEvent(event)); + return Poll::Ready(NetworkBehaviourAction::GenerateEvent(event)); } }, - Ok(Async::NotReady) => (), - Err(err) => warn!("tokio timer has errored: {:?}", err), + Poll::Pending => (), + Poll::Ready(Err(err)) => warn!("tokio timer has errored: {:?}", err), } } // Polling the mDNS service, and obtain the list of nodes discovered this round. 
let discovered = loop { - let event = match self.service.poll() { - Async::Ready(ev) => ev, - Async::NotReady => return Async::NotReady, + let event = match self.service.poll(cx) { + Poll::Ready(ev) => ev, + Poll::Pending => return Poll::Pending, }; match event { @@ -274,8 +274,8 @@ where .fold(None, |exp, &(_, _, elem_exp)| { Some(exp.map(|exp| cmp::min(exp, elem_exp)).unwrap_or(elem_exp)) }) - .map(Delay::new); - Async::Ready(NetworkBehaviourAction::GenerateEvent(MdnsEvent::Discovered(DiscoveredAddrsIter { + .map(Delay::new_at); + Poll::Ready(NetworkBehaviourAction::GenerateEvent(MdnsEvent::Discovered(DiscoveredAddrsIter { inner: discovered.into_iter(), }))) } diff --git a/misc/mdns/src/service.rs b/misc/mdns/src/service.rs index 1721a656..f3e2ba3f 100644 --- a/misc/mdns/src/service.rs +++ b/misc/mdns/src/service.rs @@ -19,14 +19,13 @@ // DEALINGS IN THE SOFTWARE. use crate::{SERVICE_NAME, META_QUERY_SERVICE, dns}; +use async_std::net::UdpSocket; use dns_parser::{Packet, RData}; -use futures::{prelude::*, task}; +use futures::prelude::*; use libp2p_core::{Multiaddr, PeerId}; use multiaddr::Protocol; -use std::{fmt, io, net::Ipv4Addr, net::SocketAddr, str, time::Duration}; -use tokio_reactor::Handle; -use wasm_timer::{Instant, Interval}; -use tokio_udp::UdpSocket; +use std::{fmt, io, net::Ipv4Addr, net::SocketAddr, pin::Pin, str, task::Context, task::Poll, time::Duration}; +use wasm_timer::Interval; pub use dns::MdnsResponseError; @@ -63,8 +62,8 @@ pub use dns::MdnsResponseError; /// let _future_to_poll = futures::stream::poll_fn(move || -> Poll, io::Error> { /// loop { /// let packet = match service.poll() { -/// Async::Ready(packet) => packet, -/// Async::NotReady => return Ok(Async::NotReady), +/// Poll::Ready(packet) => packet, +/// Poll::Pending => return Poll::Pending, /// }; /// /// match packet { @@ -113,18 +112,18 @@ pub struct MdnsService { impl MdnsService { /// Starts a new mDNS service. #[inline] - pub fn new() -> io::Result { - Self::new_inner(false) + pub async fn new() -> io::Result { + Self::new_inner(false).await } /// Same as `new`, but we don't send automatically send queries on the network. #[inline] - pub fn silent() -> io::Result { - Self::new_inner(true) + pub async fn silent() -> io::Result { + Self::new_inner(true).await } /// Starts a new mDNS service. - fn new_inner(silent: bool) -> io::Result { + async fn new_inner(silent: bool) -> io::Result { let socket = { #[cfg(unix)] fn platform_specific(s: &net2::UdpBuilder) -> io::Result<()> { @@ -139,7 +138,7 @@ impl MdnsService { builder.bind(("0.0.0.0", 5353))? }; - let socket = UdpSocket::from_std(socket, &Handle::default())?; + let socket = UdpSocket::from(socket); socket.set_multicast_loop_v4(true)?; socket.set_multicast_ttl_v4(255)?; // TODO: correct interfaces? @@ -147,8 +146,8 @@ impl MdnsService { Ok(MdnsService { socket, - query_socket: UdpSocket::bind(&From::from(([0, 0, 0, 0], 0)))?, - query_interval: Interval::new(Instant::now(), Duration::from_secs(20)), + query_socket: UdpSocket::bind((Ipv4Addr::from([0u8, 0, 0, 0]), 0u16)).await?, + query_interval: Interval::new(Duration::from_secs(20)), silent, recv_buffer: [0; 2048], send_buffers: Vec::new(), @@ -156,36 +155,28 @@ impl MdnsService { }) } - /// Polls the service for packets. - pub fn poll(&mut self) -> Async> { + pub async fn next_packet(&mut self) -> MdnsPacket { + // TODO: refactor this block // Send a query every time `query_interval` fires. 
// Note that we don't use a loop here—it is pretty unlikely that we need it, and there is // no point in sending multiple requests in a row. - match self.query_interval.poll() { - Ok(Async::Ready(_)) => { + match Stream::poll_next(Pin::new(&mut self.query_interval), cx) { + Poll::Ready(_) => { if !self.silent { let query = dns::build_query(); self.query_send_buffers.push(query.to_vec()); } } - Ok(Async::NotReady) => (), - _ => unreachable!("A wasm_timer::Interval never errors"), // TODO: is that true? + Poll::Pending => (), }; // Flush the send buffer of the main socket. while !self.send_buffers.is_empty() { let to_send = self.send_buffers.remove(0); - match self - .socket - .poll_send_to(&to_send, &From::from(([224, 0, 0, 251], 5353))) - { - Ok(Async::Ready(bytes_written)) => { + match self.socket.send_to(&to_send, &From::from(([224, 0, 0, 251], 5353))).await { + Ok(bytes_written) => { debug_assert_eq!(bytes_written, to_send.len()); } - Ok(Async::NotReady) => { - self.send_buffers.insert(0, to_send); - break; - } Err(_) => { // Errors are non-fatal because they can happen for example if we lose // connection to the network. @@ -199,17 +190,10 @@ impl MdnsService { // This has to be after the push to `query_send_buffers`. while !self.query_send_buffers.is_empty() { let to_send = self.query_send_buffers.remove(0); - match self - .query_socket - .poll_send_to(&to_send, &From::from(([224, 0, 0, 251], 5353))) - { - Ok(Async::Ready(bytes_written)) => { + match self.socket.send_to(&to_send, &From::from(([224, 0, 0, 251], 5353))).await { + Ok(bytes_written) => { debug_assert_eq!(bytes_written, to_send.len()); } - Ok(Async::NotReady) => { - self.query_send_buffers.insert(0, to_send); - break; - } Err(_) => { // Errors are non-fatal because they can happen for example if we lose // connection to the network. @@ -219,9 +203,10 @@ impl MdnsService { } } + // TODO: block needs to be refactored // Check for any incoming packet. - match self.socket.poll_recv_from(&mut self.recv_buffer) { - Ok(Async::Ready((len, from))) => { + match AsyncDatagram::poll_recv_from(Pin::new(&mut self.socket), cx, &mut self.recv_buffer) { + Poll::Ready(Ok((len, from))) => { match Packet::parse(&self.recv_buffer[..len]) { Ok(packet) => { if packet.header.query { @@ -230,7 +215,7 @@ impl MdnsService { .iter() .any(|q| q.qname.to_string().as_bytes() == SERVICE_NAME) { - return Async::Ready(MdnsPacket::Query(MdnsQuery { + return Poll::Ready(MdnsPacket::Query(MdnsQuery { from, query_id: packet.header.id, send_buffers: &mut self.send_buffers, @@ -241,7 +226,7 @@ impl MdnsService { .any(|q| q.qname.to_string().as_bytes() == META_QUERY_SERVICE) { // TODO: what if multiple questions, one with SERVICE_NAME and one with META_QUERY_SERVICE? - return Async::Ready(MdnsPacket::ServiceDiscovery( + return Poll::Ready(MdnsPacket::ServiceDiscovery( MdnsServiceDiscovery { from, query_id: packet.header.id, @@ -253,11 +238,11 @@ impl MdnsService { // writing of this code non-lexical lifetimes haven't been merged // yet, and I can't manage to write this code without having borrow // issues. - task::current().notify(); - return Async::NotReady; + cx.waker().wake_by_ref(); + return Poll::Pending; } } else { - return Async::Ready(MdnsPacket::Response(MdnsResponse { + return Poll::Ready(MdnsPacket::Response(MdnsResponse { packet, from, })); @@ -269,19 +254,17 @@ impl MdnsService { // Note that ideally we would use a loop instead. 
However as of the writing // of this code non-lexical lifetimes haven't been merged yet, and I can't // manage to write this code without having borrow issues. - task::current().notify(); - return Async::NotReady; + cx.waker().wake_by_ref(); + return Poll::Pending; } } } - Ok(Async::NotReady) => (), - Err(_) => { + Poll::Pending => (), + Poll::Ready(Err(_)) => { // Error are non-fatal and can happen if we get disconnected from example. // The query interval will wake up the task at some point so that we can try again. } }; - - Async::NotReady } } @@ -537,20 +520,20 @@ impl<'a> fmt::Debug for MdnsPeer<'a> { #[cfg(test)] mod tests { + use futures::prelude::*; use libp2p_core::PeerId; - use std::{io, time::Duration}; - use tokio::{self, prelude::*}; + use std::{io, task::Poll, time::Duration}; use crate::service::{MdnsPacket, MdnsService}; #[test] fn discover_ourselves() { let mut service = MdnsService::new().unwrap(); let peer_id = PeerId::random(); - let stream = stream::poll_fn(move || -> Poll, io::Error> { + let stream = stream::poll_fn(move |cx| -> Poll>> { loop { - let packet = match service.poll() { - Async::Ready(packet) => packet, - Async::NotReady => return Ok(Async::NotReady), + let packet = match service.poll(cx) { + Poll::Ready(packet) => packet, + Poll::Pending => return Poll::Pending, }; match packet { @@ -560,7 +543,7 @@ mod tests { MdnsPacket::Response(response) => { for peer in response.discovered_peers() { if peer.id() == &peer_id { - return Ok(Async::Ready(None)); + return Poll::Ready(None); } } } @@ -569,10 +552,10 @@ mod tests { } }); - tokio::run( + futures::executor::block_on( stream .map_err(|err| panic!("{:?}", err)) - .for_each(|_| Ok(())), + .for_each(|_| future::ready(())), ); } } diff --git a/misc/rw-stream-sink/Cargo.toml b/misc/rw-stream-sink/Cargo.toml index a10be35a..b1e0edaa 100644 --- a/misc/rw-stream-sink/Cargo.toml +++ b/misc/rw-stream-sink/Cargo.toml @@ -10,6 +10,4 @@ keywords = ["networking"] categories = ["network-programming", "asynchronous"] [dependencies] -bytes = "0.4" -futures = "0.1" -tokio-io = "0.1" +futures-preview = "0.3.0-alpha.17" diff --git a/misc/rw-stream-sink/src/lib.rs b/misc/rw-stream-sink/src/lib.rs index d73cb5d6..6325f88a 100644 --- a/misc/rw-stream-sink/src/lib.rs +++ b/misc/rw-stream-sink/src/lib.rs @@ -19,7 +19,7 @@ // DEALINGS IN THE SOFTWARE. //! This crate provides the `RwStreamSink` type. It wraps around a `Stream + Sink` that produces -//! and accepts byte arrays, and implements `AsyncRead` and `AsyncWrite`. +//! and accepts byte arrays, and implements `PollRead` and `PollWrite`. //! //! Each call to `write()` will send one packet on the sink. Calls to `read()` will read from //! incoming packets. @@ -27,112 +27,93 @@ //! > **Note**: Although this crate is hosted in the libp2p repo, it is purely a utility crate and //! > not at all specific to libp2p. -use bytes::{Buf, IntoBuf}; -use futures::{Async, AsyncSink, Poll, Sink, Stream}; -use std::cmp; -use std::io::Error as IoError; -use std::io::ErrorKind as IoErrorKind; -use std::io::{Read, Write}; -use tokio_io::{AsyncRead, AsyncWrite}; +use futures::{prelude::*, io::Initializer}; +use std::{cmp, io, marker::PhantomData, pin::Pin, task::Context, task::Poll}; /// Wraps around a `Stream + Sink` whose items are buffers. Implements `AsyncRead` and `AsyncWrite`. -pub struct RwStreamSink -where - S: Stream, - S::Item: IntoBuf, -{ +/// +/// The `B` generic is the type of buffers that the `Sink` accepts. The `I` generic is the type of +/// buffer that the `Stream` generates. 
+pub struct RwStreamSink { inner: S, - current_item: Option<::Buf>, + current_item: Option>, } -impl RwStreamSink -where - S: Stream, - S::Item: IntoBuf, -{ +impl RwStreamSink { /// Wraps around `inner`. pub fn new(inner: S) -> RwStreamSink { RwStreamSink { inner, current_item: None } } } -impl Read for RwStreamSink +impl AsyncRead for RwStreamSink where - S: Stream, - S::Item: IntoBuf, + S: TryStream, Error = io::Error> + Unpin, { - fn read(&mut self, buf: &mut [u8]) -> Result { + fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { // Grab the item to copy from. - let item_to_copy = loop { + let current_item = loop { if let Some(ref mut i) = self.current_item { - if i.has_remaining() { + if !i.is_empty() { break i; } } - self.current_item = Some(match self.inner.poll()? { - Async::Ready(Some(i)) => i.into_buf(), - Async::Ready(None) => return Ok(0), // EOF - Async::NotReady => return Err(IoErrorKind::WouldBlock.into()), + self.current_item = Some(match TryStream::try_poll_next(Pin::new(&mut self.inner), cx) { + Poll::Ready(Some(Ok(i))) => i, + Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err)), + Poll::Ready(None) => return Poll::Ready(Ok(0)), // EOF + Poll::Pending => return Poll::Pending, }); }; // Copy it! - debug_assert!(item_to_copy.has_remaining()); - let to_copy = cmp::min(buf.len(), item_to_copy.remaining()); - item_to_copy.take(to_copy).copy_to_slice(&mut buf[..to_copy]); - Ok(to_copy) - } -} - -impl AsyncRead for RwStreamSink -where - S: Stream, - S::Item: IntoBuf, -{ - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { - false - } -} - -impl Write for RwStreamSink -where - S: Stream + Sink, - S::SinkItem: for<'r> From<&'r [u8]>, - S::Item: IntoBuf, -{ - fn write(&mut self, buf: &[u8]) -> Result { - let len = buf.len(); - match self.inner.start_send(buf.into())? { - AsyncSink::Ready => Ok(len), - AsyncSink::NotReady(_) => Err(IoError::new(IoErrorKind::WouldBlock, "not ready")), - } + debug_assert!(!current_item.is_empty()); + let to_copy = cmp::min(buf.len(), current_item.len()); + buf[..to_copy].copy_from_slice(¤t_item[..to_copy]); + for _ in 0..to_copy { current_item.remove(0); } + Poll::Ready(Ok(to_copy)) } - fn flush(&mut self) -> Result<(), IoError> { - match self.inner.poll_complete()? 
{ - Async::Ready(()) => Ok(()), - Async::NotReady => Err(IoError::new(IoErrorKind::WouldBlock, "not ready")) - } + unsafe fn initializer(&self) -> Initializer { + Initializer::nop() } } impl AsyncWrite for RwStreamSink where - S: Stream + Sink, - S::SinkItem: for<'r> From<&'r [u8]>, - S::Item: IntoBuf, + S: Stream + Sink, Error = io::Error> + Unpin, { - fn shutdown(&mut self) -> Poll<(), IoError> { - self.inner.close() + fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + match Sink::poll_ready(Pin::new(&mut self.inner), cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(())) => {} + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + } + + let len = buf.len(); + match Sink::start_send(Pin::new(&mut self.inner), buf.into()) { + Ok(()) => Poll::Ready(Ok(len)), + Err(err) => Poll::Ready(Err(err)) + } } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Sink::poll_flush(Pin::new(&mut self.inner), cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Sink::poll_close(Pin::new(&mut self.inner), cx) + } +} + +impl Unpin for RwStreamSink { } #[cfg(test)] mod tests { - use bytes::Bytes; use crate::RwStreamSink; - use futures::{prelude::*, stream, sync::mpsc::channel}; + use futures::{prelude::*, stream, channel::mpsc::channel}; use std::io::Read; // This struct merges a stream and a sink and is quite useful for tests. diff --git a/muxers/mplex/Cargo.toml b/muxers/mplex/Cargo.toml index 18ab0735..d1b51994 100644 --- a/muxers/mplex/Cargo.toml +++ b/muxers/mplex/Cargo.toml @@ -12,14 +12,12 @@ categories = ["network-programming", "asynchronous"] [dependencies] bytes = "0.4.5" fnv = "1.0" -futures = "0.1" +futures_codec = "0.2.4" +futures-preview = "0.3.0-alpha.17" libp2p-core = { version = "0.12.0", path = "../../core" } log = "0.4" parking_lot = "0.8" -tokio-codec = "0.1" -tokio-io = "0.1" -unsigned-varint = { version = "0.2.1", features = ["codec"] } +unsigned-varint = { git = "https://github.com/tomaka/unsigned-varint", branch = "futures-codec", features = ["codec"] } [dev-dependencies] libp2p-tcp = { version = "0.12.0", path = "../../transports/tcp" } -tokio = "0.1" diff --git a/muxers/mplex/src/codec.rs b/muxers/mplex/src/codec.rs index 012862ba..e04aa4c2 100644 --- a/muxers/mplex/src/codec.rs +++ b/muxers/mplex/src/codec.rs @@ -19,10 +19,10 @@ // DEALINGS IN THE SOFTWARE. use libp2p_core::Endpoint; +use futures_codec::{Decoder, Encoder}; use std::io::{Error as IoError, ErrorKind as IoErrorKind}; use std::mem; use bytes::{BufMut, Bytes, BytesMut}; -use tokio_io::codec::{Decoder, Encoder}; use unsigned_varint::{codec, encode}; // Maximum size for a packet: 1MB as per the spec. 
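// ----------------------------------------------------------------------------
// Editorial aside (not part of the patch): the pattern the rewritten
// `RwStreamSink` uses for its `AsyncWrite` side. Under futures 0.3 a `Sink`
// is driven in two steps: `poll_ready` waits for capacity, then `start_send`
// queues the item. Mapping `AsyncWrite::poll_write` onto those two calls
// looks roughly like the sketch below; the `SinkWriter` name and the
// `Vec<u8>` item type are illustrative assumptions, not code from this patch.
// ----------------------------------------------------------------------------
use futures::{prelude::*, ready};
use std::{io, pin::Pin, task::{Context, Poll}};

struct SinkWriter<S>(S);

impl<S> AsyncWrite for SinkWriter<S>
where
    S: Sink<Vec<u8>, Error = io::Error> + Unpin,
{
    fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
        // Wait until the sink can accept one more item, then queue the whole
        // buffer as a single packet and report it all as written.
        ready!(Pin::new(&mut self.0).poll_ready(cx))?;
        Pin::new(&mut self.0).start_send(buf.to_vec())?;
        Poll::Ready(Ok(buf.len()))
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.0).poll_flush(cx)
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.0).poll_close(cx)
    }
}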
diff --git a/muxers/mplex/src/lib.rs b/muxers/mplex/src/lib.rs index 8806b031..36ccc747 100644 --- a/muxers/mplex/src/lib.rs +++ b/muxers/mplex/src/lib.rs @@ -20,9 +20,10 @@ mod codec; -use std::{cmp, iter, mem}; +use std::{cmp, iter, mem, pin::Pin, task::Context, task::Poll}; use std::io::{Error as IoError, ErrorKind as IoErrorKind}; -use std::sync::{atomic::AtomicUsize, atomic::Ordering, Arc}; +use std::sync::Arc; +use std::task::Waker; use bytes::Bytes; use libp2p_core::{ Endpoint, @@ -31,10 +32,10 @@ use libp2p_core::{ }; use log::{debug, trace}; use parking_lot::Mutex; -use fnv::{FnvHashMap, FnvHashSet}; -use futures::{prelude::*, executor, future, stream::Fuse, task, task_local, try_ready}; -use tokio_codec::Framed; -use tokio_io::{AsyncRead, AsyncWrite}; +use fnv::FnvHashSet; +use futures::{prelude::*, future, io::Initializer, ready, stream::Fuse}; +use futures::task::{ArcWake, waker_ref}; +use futures_codec::Framed; /// Configuration for the multiplexer. #[derive(Debug, Clone)] @@ -96,22 +97,22 @@ impl MplexConfig { #[inline] fn upgrade(self, i: C) -> Multiplex where - C: AsyncRead + AsyncWrite + C: AsyncRead + AsyncWrite + Unpin { let max_buffer_len = self.max_buffer_len; Multiplex { inner: Mutex::new(MultiplexInner { error: Ok(()), - inner: executor::spawn(Framed::new(i, codec::Codec::new()).fuse()), + inner: Framed::new(i, codec::Codec::new()).fuse(), config: self, buffer: Vec::with_capacity(cmp::min(max_buffer_len, 512)), opened_substreams: Default::default(), next_outbound_stream_id: 0, notifier_read: Arc::new(Notifier { - to_notify: Mutex::new(Default::default()), + to_wake: Mutex::new(Default::default()), }), notifier_write: Arc::new(Notifier { - to_notify: Mutex::new(Default::default()), + to_wake: Mutex::new(Default::default()), }), is_shutdown: false, is_acknowledged: false, @@ -156,27 +157,27 @@ impl UpgradeInfo for MplexConfig { impl InboundUpgrade for MplexConfig where - C: AsyncRead + AsyncWrite, + C: AsyncRead + AsyncWrite + Unpin, { type Output = Multiplex>; type Error = IoError; - type Future = future::FutureResult; + type Future = future::Ready>; fn upgrade_inbound(self, socket: Negotiated, _: Self::Info) -> Self::Future { - future::ok(self.upgrade(socket)) + future::ready(Ok(self.upgrade(socket))) } } impl OutboundUpgrade for MplexConfig where - C: AsyncRead + AsyncWrite, + C: AsyncRead + AsyncWrite + Unpin, { type Output = Multiplex>; type Error = IoError; - type Future = future::FutureResult; + type Future = future::Ready>; fn upgrade_outbound(self, socket: Negotiated, _: Self::Info) -> Self::Future { - future::ok(self.upgrade(socket)) + future::ready(Ok(self.upgrade(socket))) } } @@ -190,7 +191,7 @@ struct MultiplexInner { // Error that happened earlier. Should poison any attempt to use this `MultiplexError`. error: Result<(), IoError>, // Underlying stream. - inner: executor::Spawn>>, + inner: Fuse>, /// The original configuration. config: MplexConfig, // Buffer of elements pulled from the stream but not processed yet. @@ -202,9 +203,9 @@ struct MultiplexInner { opened_substreams: FnvHashSet<(u32, Endpoint)>, // Id of the next outgoing substream. next_outbound_stream_id: u32, - /// List of tasks to notify when a read event happens on the underlying stream. + /// List of wakers to wake when a read event happens on the underlying stream. notifier_read: Arc, - /// List of tasks to notify when a write event happens on the underlying stream. + /// List of wakers to wake when a write event happens on the underlying stream. 
notifier_write: Arc, /// If true, the connection has been shut down. We need to be careful not to accidentally /// call `Sink::poll_complete` or `Sink::start_send` after `Sink::close`. @@ -214,23 +215,26 @@ struct MultiplexInner { } struct Notifier { - /// List of tasks to notify. - to_notify: Mutex>, + /// List of wakers to wake. + to_wake: Mutex>, } -impl executor::Notify for Notifier { - fn notify(&self, _: usize) { - let tasks = mem::replace(&mut *self.to_notify.lock(), Default::default()); - for (_, task) in tasks { - task.notify(); +impl Notifier { + fn insert(&self, waker: &Waker) { + let mut to_wake = self.to_wake.lock(); + if to_wake.iter().all(|w| !w.will_wake(waker)) { + to_wake.push(waker.clone()); } } } -// TODO: replace with another system -static NEXT_TASK_ID: AtomicUsize = AtomicUsize::new(0); -task_local!{ - static TASK_ID: usize = NEXT_TASK_ID.fetch_add(1, Ordering::Relaxed) +impl ArcWake for Notifier { + fn wake_by_ref(arc_self: &Arc) { + let wakers = mem::replace(&mut *arc_self.to_wake.lock(), Default::default()); + for waker in wakers { + waker.wake(); + } + } } // Note [StreamId]: mplex no longer partitions stream IDs into odd (for initiators) and @@ -245,25 +249,27 @@ task_local!{ /// Processes elements in `inner` until one matching `filter` is found. /// -/// If `NotReady` is returned, the current task is scheduled for later, just like with any `Poll`. -/// `Ready(Some())` is almost always returned. An error is returned if the stream is EOF. -fn next_match(inner: &mut MultiplexInner, mut filter: F) -> Poll -where C: AsyncRead + AsyncWrite, +/// If `Pending` is returned, the waker is kept and notifier later, just like with any `Poll`. +/// `Ready(Ok())` is almost always returned. An error is returned if the stream is EOF. +fn next_match(inner: &mut MultiplexInner, cx: &mut Context, mut filter: F) -> Poll> +where C: AsyncRead + AsyncWrite + Unpin, F: FnMut(&codec::Elem) -> Option, { // If an error happened earlier, immediately return it. if let Err(ref err) = inner.error { - return Err(IoError::new(err.kind(), err.to_string())); + return Poll::Ready(Err(IoError::new(err.kind(), err.to_string()))); } if let Some((offset, out)) = inner.buffer.iter().enumerate().filter_map(|(n, v)| filter(v).map(|v| (n, v))).next() { + // Found a matching entry in the existing buffer! + // The buffer was full and no longer is, so let's notify everything. 
if inner.buffer.len() == inner.config.max_buffer_len { - executor::Notify::notify(&*inner.notifier_read, 0); + ArcWake::wake_by_ref(&inner.notifier_read); } inner.buffer.remove(offset); - return Ok(Async::Ready(out)); + return Poll::Ready(Ok(out)); } loop { @@ -274,24 +280,24 @@ where C: AsyncRead + AsyncWrite, match inner.config.max_buffer_behaviour { MaxBufferBehaviour::CloseAll => { inner.error = Err(IoError::new(IoErrorKind::Other, "reached maximum buffer length")); - return Err(IoError::new(IoErrorKind::Other, "reached maximum buffer length")); + return Poll::Ready(Err(IoError::new(IoErrorKind::Other, "reached maximum buffer length"))); }, MaxBufferBehaviour::Block => { - inner.notifier_read.to_notify.lock().insert(TASK_ID.with(|&t| t), task::current()); - return Ok(Async::NotReady); + inner.notifier_read.insert(cx.waker()); + return Poll::Pending }, } } - inner.notifier_read.to_notify.lock().insert(TASK_ID.with(|&t| t), task::current()); - let elem = match inner.inner.poll_stream_notify(&inner.notifier_read, 0) { - Ok(Async::Ready(Some(item))) => item, - Ok(Async::Ready(None)) => return Err(IoErrorKind::BrokenPipe.into()), - Ok(Async::NotReady) => return Ok(Async::NotReady), - Err(err) => { + inner.notifier_read.insert(cx.waker()); + let elem = match Stream::poll_next(Pin::new(&mut inner.inner), &mut Context::from_waker(&waker_ref(&inner.notifier_read))) { + Poll::Ready(Some(Ok(item))) => item, + Poll::Ready(None) => return Poll::Ready(Err(IoErrorKind::BrokenPipe.into())), + Poll::Pending => return Poll::Pending, + Poll::Ready(Some(Err(err))) => { let err2 = IoError::new(err.kind(), err.to_string()); inner.error = Err(err); - return Err(err2); + return Poll::Ready(Err(err2)); }, }; @@ -312,7 +318,7 @@ where C: AsyncRead + AsyncWrite, } if let Some(out) = filter(&elem) { - return Ok(Async::Ready(out)); + return Poll::Ready(Ok(out)); } else { let endpoint = elem.endpoint().unwrap_or(Endpoint::Dialer); if inner.opened_substreams.contains(&(elem.substream_id(), !endpoint)) || elem.is_open_msg() { @@ -325,45 +331,57 @@ where C: AsyncRead + AsyncWrite, } // Small convenience function that tries to write `elem` to the stream. 
-fn poll_send(inner: &mut MultiplexInner, elem: codec::Elem) -> Poll<(), IoError> -where C: AsyncRead + AsyncWrite +fn poll_send(inner: &mut MultiplexInner, cx: &mut Context, elem: codec::Elem) -> Poll> +where C: AsyncRead + AsyncWrite + Unpin { if inner.is_shutdown { - return Err(IoError::new(IoErrorKind::Other, "connection is shut down")) + return Poll::Ready(Err(IoError::new(IoErrorKind::Other, "connection is shut down"))) } - inner.notifier_write.to_notify.lock().insert(TASK_ID.with(|&t| t), task::current()); - match inner.inner.start_send_notify(elem, &inner.notifier_write, 0) { - Ok(AsyncSink::Ready) => Ok(Async::Ready(())), - Ok(AsyncSink::NotReady(_)) => Ok(Async::NotReady), - Err(err) => Err(err) + + inner.notifier_write.insert(cx.waker()); + + match Sink::poll_ready(Pin::new(&mut inner.inner), &mut Context::from_waker(&waker_ref(&inner.notifier_write))) { + Poll::Ready(Ok(())) => { + match Sink::start_send(Pin::new(&mut inner.inner), elem) { + Ok(()) => Poll::Ready(Ok(())), + Err(err) => Poll::Ready(Err(err)) + } + }, + Poll::Pending => Poll::Pending, + Poll::Ready(Err(err)) => Poll::Ready(Err(err)) } } impl StreamMuxer for Multiplex -where C: AsyncRead + AsyncWrite +where C: AsyncRead + AsyncWrite + Unpin { type Substream = Substream; type OutboundSubstream = OutboundSubstream; type Error = IoError; - fn poll_inbound(&self) -> Poll { + fn poll_inbound(&self, cx: &mut Context) -> Poll> { let mut inner = self.inner.lock(); if inner.opened_substreams.len() >= inner.config.max_substreams { debug!("Refused substream; reached maximum number of substreams {}", inner.config.max_substreams); - return Err(IoError::new(IoErrorKind::ConnectionRefused, - "exceeded maximum number of open substreams")); + return Poll::Ready(Err(IoError::new(IoErrorKind::ConnectionRefused, + "exceeded maximum number of open substreams"))); } - let num = try_ready!(next_match(&mut inner, |elem| { + let num = ready!(next_match(&mut inner, cx, |elem| { match elem { codec::Elem::Open { substream_id } => Some(*substream_id), _ => None, } })); + let num = match num { + Ok(n) => n, + Err(err) => return Poll::Ready(Err(err)), + }; + debug!("Successfully opened inbound substream {}", num); - Ok(Async::Ready(Substream { + Poll::Ready(Ok(Substream { current_data: Bytes::new(), num, endpoint: Endpoint::Listener, @@ -391,21 +409,21 @@ where C: AsyncRead + AsyncWrite } } - fn poll_outbound(&self, substream: &mut Self::OutboundSubstream) -> Poll { + fn poll_outbound(&self, cx: &mut Context, substream: &mut Self::OutboundSubstream) -> Poll> { loop { let mut inner = self.inner.lock(); let polling = match substream.state { OutboundSubstreamState::SendElem(ref elem) => { - poll_send(&mut inner, elem.clone()) + poll_send(&mut inner, cx, elem.clone()) }, OutboundSubstreamState::Flush => { if inner.is_shutdown { - return Err(IoError::new(IoErrorKind::Other, "connection is shut down")) + return Poll::Ready(Err(IoError::new(IoErrorKind::Other, "connection is shut down"))) } let inner = &mut *inner; // Avoids borrow errors - inner.notifier_write.to_notify.lock().insert(TASK_ID.with(|&t| t), task::current()); - inner.inner.poll_flush_notify(&inner.notifier_write, 0) + inner.notifier_write.insert(cx.waker()); + Sink::poll_flush(Pin::new(&mut inner.inner), &mut Context::from_waker(&waker_ref(&inner.notifier_write))) }, OutboundSubstreamState::Done => { panic!("Polling outbound substream after it's been succesfully open"); @@ -413,16 +431,14 @@ where C: AsyncRead + AsyncWrite }; match polling { - Ok(Async::Ready(())) => (), - 
Ok(Async::NotReady) => { - return Ok(Async::NotReady) - }, - Err(err) => { + Poll::Ready(Ok(())) => (), + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(err)) => { debug!("Failed to open outbound substream {}", substream.num); inner.buffer.retain(|elem| { elem.substream_id() != substream.num || elem.endpoint() == Some(Endpoint::Dialer) }); - return Err(err) + return Poll::Ready(Err(err)); }, }; @@ -436,7 +452,7 @@ where C: AsyncRead + AsyncWrite OutboundSubstreamState::Flush => { debug!("Successfully opened outbound substream {}", substream.num); substream.state = OutboundSubstreamState::Done; - return Ok(Async::Ready(Substream { + return Poll::Ready(Ok(Substream { num: substream.num, current_data: Bytes::new(), endpoint: Endpoint::Dialer, @@ -454,27 +470,27 @@ where C: AsyncRead + AsyncWrite // Nothing to do. } - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { - false + unsafe fn initializer(&self) -> Initializer { + Initializer::nop() } - fn read_substream(&self, substream: &mut Self::Substream, buf: &mut [u8]) -> Poll { + fn read_substream(&self, cx: &mut Context, substream: &mut Self::Substream, buf: &mut [u8]) -> Poll> { loop { // First, transfer from `current_data`. if !substream.current_data.is_empty() { let len = cmp::min(substream.current_data.len(), buf.len()); buf[..len].copy_from_slice(&substream.current_data.split_to(len)); - return Ok(Async::Ready(len)); + return Poll::Ready(Ok(len)); } // If the remote writing side is closed, return EOF. if !substream.remote_open { - return Ok(Async::Ready(0)); + return Poll::Ready(Ok(0)); } // Try to find a packet of data in the buffer. let mut inner = self.inner.lock(); - let next_data_poll = next_match(&mut inner, |elem| { + let next_data_poll = next_match(&mut inner, cx, |elem| { match elem { codec::Elem::Data { substream_id, endpoint, data, .. } if *substream_id == substream.num && *endpoint != substream.endpoint => // see note [StreamId] @@ -492,28 +508,29 @@ where C: AsyncRead + AsyncWrite // We're in a loop, so all we need to do is set `substream.current_data` to the data we // just read and wait for the next iteration. - match next_data_poll? { - Async::Ready(Some(data)) => substream.current_data = data, - Async::Ready(None) => { + match next_data_poll { + Poll::Ready(Ok(Some(data))) => substream.current_data = data, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Ready(Ok(None)) => { substream.remote_open = false; - return Ok(Async::Ready(0)); + return Poll::Ready(Ok(0)); }, - Async::NotReady => { + Poll::Pending => { // There was no data packet in the buffer about this substream; maybe it's // because it has been closed. if inner.opened_substreams.contains(&(substream.num, substream.endpoint)) { - return Ok(Async::NotReady) + return Poll::Pending } else { - return Ok(Async::Ready(0)) + return Poll::Ready(Ok(0)) } }, } } } - fn write_substream(&self, substream: &mut Self::Substream, buf: &[u8]) -> Poll { + fn write_substream(&self, cx: &mut Context, substream: &mut Self::Substream, buf: &[u8]) -> Poll> { if !substream.local_open { - return Err(IoErrorKind::BrokenPipe.into()); + return Poll::Ready(Err(IoErrorKind::BrokenPipe.into())); } let mut inner = self.inner.lock(); @@ -526,26 +543,27 @@ where C: AsyncRead + AsyncWrite endpoint: substream.endpoint, }; - match poll_send(&mut inner, elem)? 
{ - Async::Ready(()) => Ok(Async::Ready(to_write)), - Async::NotReady => Ok(Async::NotReady) + match poll_send(&mut inner, cx, elem) { + Poll::Ready(Ok(())) => Poll::Ready(Ok(to_write)), + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + Poll::Pending => Poll::Pending, } } - fn flush_substream(&self, _substream: &mut Self::Substream) -> Poll<(), IoError> { + fn flush_substream(&self, cx: &mut Context, _substream: &mut Self::Substream) -> Poll> { let mut inner = self.inner.lock(); if inner.is_shutdown { - return Err(IoError::new(IoErrorKind::Other, "connection is shut down")) + return Poll::Ready(Err(IoError::new(IoErrorKind::Other, "connection is shut down"))) } let inner = &mut *inner; // Avoids borrow errors - inner.notifier_write.to_notify.lock().insert(TASK_ID.with(|&t| t), task::current()); - inner.inner.poll_flush_notify(&inner.notifier_write, 0) + inner.notifier_write.insert(cx.waker()); + Sink::poll_flush(Pin::new(&mut inner.inner), &mut Context::from_waker(&waker_ref(&inner.notifier_write))) } - fn shutdown_substream(&self, sub: &mut Self::Substream) -> Poll<(), IoError> { + fn shutdown_substream(&self, cx: &mut Context, sub: &mut Self::Substream) -> Poll> { if !sub.local_open { - return Ok(Async::Ready(())); + return Poll::Ready(Ok(())); } let elem = codec::Elem::Close { @@ -554,8 +572,8 @@ where C: AsyncRead + AsyncWrite }; let mut inner = self.inner.lock(); - let result = poll_send(&mut inner, elem); - if let Ok(Async::Ready(())) = result { + let result = poll_send(&mut inner, cx, elem); + if let Poll::Ready(Ok(())) = result { sub.local_open = false; } result @@ -572,22 +590,27 @@ where C: AsyncRead + AsyncWrite } #[inline] - fn close(&self) -> Poll<(), IoError> { + fn close(&self, cx: &mut Context) -> Poll> { let inner = &mut *self.inner.lock(); - inner.notifier_write.to_notify.lock().insert(TASK_ID.with(|&t| t), task::current()); - try_ready!(inner.inner.close_notify(&inner.notifier_write, 0)); - inner.is_shutdown = true; - Ok(Async::Ready(())) + inner.notifier_write.insert(cx.waker()); + match Sink::poll_close(Pin::new(&mut inner.inner), &mut Context::from_waker(&waker_ref(&inner.notifier_write))) { + Poll::Ready(Ok(())) => { + inner.is_shutdown = true; + Poll::Ready(Ok(())) + } + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + Poll::Pending => Poll::Pending, + } } #[inline] - fn flush_all(&self) -> Poll<(), IoError> { + fn flush_all(&self, cx: &mut Context) -> Poll> { let inner = &mut *self.inner.lock(); if inner.is_shutdown { - return Ok(Async::Ready(())) + return Poll::Ready(Ok(())) } - inner.notifier_write.to_notify.lock().insert(TASK_ID.with(|&t| t), task::current()); - inner.inner.poll_flush_notify(&inner.notifier_write, 0) + inner.notifier_write.insert(cx.waker()); + Sink::poll_flush(Pin::new(&mut inner.inner), &mut Context::from_waker(&waker_ref(&inner.notifier_write))) } } diff --git a/protocols/deflate/Cargo.toml b/protocols/deflate/Cargo.toml index 035a7394..5c723f73 100644 --- a/protocols/deflate/Cargo.toml +++ b/protocols/deflate/Cargo.toml @@ -10,14 +10,13 @@ keywords = ["peer-to-peer", "libp2p", "networking"] categories = ["network-programming", "asynchronous"] [dependencies] -futures = "0.1" +futures-preview = "0.3.0-alpha.17" libp2p-core = { version = "0.12.0", path = "../../core" } -tokio-io = "0.1.12" -flate2 = { version = "1.0", features = ["tokio"] } +flate2 = "1.0" [dev-dependencies] +async-std = "0.99" env_logger = "0.6" libp2p-tcp = { version = "0.12.0", path = "../../transports/tcp" } +rand = "0.7" quickcheck = "0.8" -tokio = "0.1" -log = "0.4" 
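// ----------------------------------------------------------------------------
// Editorial aside (not part of the patch): the waker-registry pattern that
// replaces futures 0.1 `task::current()` / `notify()` in mplex. The patch's
// `Notifier` (muxers/mplex/src/lib.rs) keeps the `Waker` of every task that
// polled the shared connection and wakes them all via `ArcWake`. The sketch
// below mirrors that type; `poll_shared_with` is only an illustrative helper.
// ----------------------------------------------------------------------------
use futures::task::{waker_ref, ArcWake};
use parking_lot::Mutex;
use std::{mem, sync::Arc, task::{Context, Waker}};

struct Notifier {
    /// Wakers of all tasks currently blocked on the shared resource.
    to_wake: Mutex<Vec<Waker>>,
}

impl Notifier {
    /// Registers the waker of the task being polled, deduplicating with
    /// `will_wake` so the same task is not stored twice.
    fn insert(&self, waker: &Waker) {
        let mut to_wake = self.to_wake.lock();
        if to_wake.iter().all(|w| !w.will_wake(waker)) {
            to_wake.push(waker.clone());
        }
    }
}

impl ArcWake for Notifier {
    fn wake_by_ref(arc_self: &Arc<Self>) {
        // Drain the list and wake every registered task exactly once.
        for waker in mem::replace(&mut *arc_self.to_wake.lock(), Vec::new()) {
            waker.wake();
        }
    }
}

/// Registers the caller with the notifier, then builds a `Context` whose waker
/// is the notifier itself, so that progress on the shared stream wakes *all*
/// registered tasks instead of only the last one that happened to poll it.
fn poll_shared_with(notifier: &Arc<Notifier>, cx: &mut Context<'_>) {
    notifier.insert(cx.waker());
    let waker = waker_ref(notifier);
    let _inner_cx = Context::from_waker(&waker);
    // ... poll the shared `Framed` stream/sink with `_inner_cx` here ...
}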
diff --git a/protocols/deflate/src/lib.rs b/protocols/deflate/src/lib.rs index 7dbf03eb..74f33c69 100644 --- a/protocols/deflate/src/lib.rs +++ b/protocols/deflate/src/lib.rs @@ -18,21 +18,22 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use flate2::read::DeflateDecoder; -use flate2::write::DeflateEncoder; -use flate2::Compression; -use std::io; - -use futures::future::{self, FutureResult}; -use libp2p_core::{upgrade::Negotiated, InboundUpgrade, OutboundUpgrade, UpgradeInfo}; -use std::iter; -use tokio_io::{AsyncRead, AsyncWrite}; +use futures::{prelude::*, ready}; +use libp2p_core::{Negotiated, InboundUpgrade, OutboundUpgrade, UpgradeInfo}; +use std::{io, iter, pin::Pin, task::Context, task::Poll}; #[derive(Debug, Copy, Clone)] -pub struct DeflateConfig; +pub struct DeflateConfig { + compression: flate2::Compression, +} -/// Output of the deflate protocol. -pub type DeflateOutput = DeflateDecoder>; +impl Default for DeflateConfig { + fn default() -> Self { + DeflateConfig { + compression: flate2::Compression::fast(), + } + } +} impl UpgradeInfo for DeflateConfig { type Info = &'static [u8]; @@ -49,13 +50,10 @@ where { type Output = DeflateOutput>; type Error = io::Error; - type Future = FutureResult; + type Future = future::Ready>; fn upgrade_inbound(self, r: Negotiated, _: Self::Info) -> Self::Future { - future::ok(DeflateDecoder::new(DeflateEncoder::new( - r, - Compression::default(), - ))) + future::ok(DeflateOutput::new(r, self.compression)) } } @@ -65,12 +63,195 @@ where { type Output = DeflateOutput>; type Error = io::Error; - type Future = FutureResult; + type Future = future::Ready>; fn upgrade_outbound(self, w: Negotiated, _: Self::Info) -> Self::Future { - future::ok(DeflateDecoder::new(DeflateEncoder::new( - w, - Compression::default(), - ))) + future::ok(DeflateOutput::new(w, self.compression)) + } +} + +/// Decodes and encodes traffic using DEFLATE. +pub struct DeflateOutput { + /// Inner stream where we read compressed data from and write compressed data to. + inner: S, + /// Internal object used to hold the state of the compression. + compress: flate2::Compress, + /// Internal object used to hold the state of the decompression. + decompress: flate2::Decompress, + /// Temporary buffer between `compress` and `inner`. Stores compressed bytes that need to be + /// sent out once `inner` is ready to accept more. + write_out: Vec, + /// Temporary buffer between `decompress` and `inner`. Stores compressed bytes that need to be + /// given to `decompress`. + read_interm: Vec, + /// When we read from `inner` and `Ok(0)` is returned, we set this to `true` so that we don't + /// read from it again. + inner_read_eof: bool, +} + +impl DeflateOutput { + fn new(inner: S, compression: flate2::Compression) -> Self { + DeflateOutput { + inner, + compress: flate2::Compress::new(compression, false), + decompress: flate2::Decompress::new(false), + write_out: Vec::with_capacity(256), + read_interm: Vec::with_capacity(256), + inner_read_eof: false, + } + } + + /// Tries to write the content of `self.write_out` to `self.inner`. + /// Returns `Ready(Ok(()))` if `self.write_out` is empty. 
+ fn flush_write_out(&mut self, cx: &mut Context) -> Poll> + where S: AsyncWrite + Unpin + { + loop { + if self.write_out.is_empty() { + return Poll::Ready(Ok(())) + } + + match AsyncWrite::poll_write(Pin::new(&mut self.inner), cx, &self.write_out) { + Poll::Ready(Ok(0)) => return Poll::Ready(Err(io::ErrorKind::WriteZero.into())), + Poll::Ready(Ok(n)) => self.write_out = self.write_out.split_off(n), + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => return Poll::Pending, + }; + } + } +} + +impl AsyncRead for DeflateOutput + where S: AsyncRead + Unpin +{ + fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { + // We use a `this` variable because the compiler doesn't allow multiple mutable borrows + // across a `Deref`. + let this = &mut *self; + + loop { + // Read from `self.inner` into `self.read_interm` if necessary. + if this.read_interm.is_empty() && !this.inner_read_eof { + unsafe { + this.read_interm.reserve(256); + this.read_interm.set_len(this.read_interm.capacity()); + this.inner.initializer().initialize(&mut this.read_interm); + } + + match AsyncRead::poll_read(Pin::new(&mut this.inner), cx, &mut this.read_interm) { + Poll::Ready(Ok(0)) => { + this.inner_read_eof = true; + this.read_interm.clear(); + } + Poll::Ready(Ok(n)) => { + this.read_interm.truncate(n) + }, + Poll::Ready(Err(err)) => { + this.read_interm.clear(); + return Poll::Ready(Err(err)) + }, + Poll::Pending => { + this.read_interm.clear(); + return Poll::Pending + }, + } + } + debug_assert!(!this.read_interm.is_empty() || this.inner_read_eof); + + let before_out = this.decompress.total_out(); + let before_in = this.decompress.total_in(); + let ret = this.decompress.decompress(&this.read_interm, buf, if this.inner_read_eof { flate2::FlushDecompress::Finish } else { flate2::FlushDecompress::None })?; + + // Remove from `self.read_interm` the bytes consumed by the decompressor. + let consumed = (this.decompress.total_in() - before_in) as usize; + this.read_interm = this.read_interm.split_off(consumed); + + let read = (this.decompress.total_out() - before_out) as usize; + if read != 0 || ret == flate2::Status::StreamEnd { + return Poll::Ready(Ok(read)) + } + } + } + + unsafe fn initializer(&self) -> futures::io::Initializer { + futures::io::Initializer::nop() + } +} + +impl AsyncWrite for DeflateOutput + where S: AsyncWrite + Unpin +{ + fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) + -> Poll> + { + // We use a `this` variable because the compiler doesn't allow multiple mutable borrows + // across a `Deref`. + let this = &mut *self; + + // We don't want to accumulate too much data in `self.write_out`, so we only proceed if it + // is empty. + ready!(this.flush_write_out(cx))?; + + // We special-case this, otherwise an empty buffer would make the loop below infinite. + if buf.is_empty() { + return Poll::Ready(Ok(0)); + } + + // Unfortunately, the compressor might be in a "flushing mode", not accepting any input + // data. We don't want to return `Ok(0)` in that situation, as that would be wrong. + // Instead, we invoke the compressor in a loop until it accepts some of our data. 
+ loop { + let before_in = this.compress.total_in(); + this.write_out.reserve(256); // compress_vec uses the Vec's capacity + let ret = this.compress.compress_vec(buf, &mut this.write_out, flate2::FlushCompress::None)?; + let written = (this.compress.total_in() - before_in) as usize; + + if written != 0 || ret == flate2::Status::StreamEnd { + return Poll::Ready(Ok(written)); + } + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + // We use a `this` variable because the compiler doesn't allow multiple mutable borrows + // across a `Deref`. + let this = &mut *self; + + ready!(this.flush_write_out(cx))?; + this.compress.compress_vec(&[], &mut this.write_out, flate2::FlushCompress::Sync)?; + + loop { + ready!(this.flush_write_out(cx))?; + + debug_assert!(this.write_out.is_empty()); + // We ask the compressor to flush everything into `self.write_out`. + this.write_out.reserve(256); // compress_vec uses the Vec's capacity + this.compress.compress_vec(&[], &mut this.write_out, flate2::FlushCompress::None)?; + if this.write_out.is_empty() { + break; + } + } + + AsyncWrite::poll_flush(Pin::new(&mut this.inner), cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + // We use a `this` variable because the compiler doesn't allow multiple mutable borrows + // across a `Deref`. + let this = &mut *self; + + loop { + ready!(this.flush_write_out(cx))?; + + // We ask the compressor to flush everything into `self.write_out`. + debug_assert!(this.write_out.is_empty()); + this.write_out.reserve(256); // compress_vec uses the Vec's capacity + this.compress.compress_vec(&[], &mut this.write_out, flate2::FlushCompress::Finish)?; + if this.write_out.is_empty() { + break; + } + } + + AsyncWrite::poll_close(Pin::new(&mut this.inner), cx) } } diff --git a/protocols/deflate/tests/test.rs b/protocols/deflate/tests/test.rs index a0b2c07f..28a0c1fd 100644 --- a/protocols/deflate/tests/test.rs +++ b/protocols/deflate/tests/test.rs @@ -18,23 +18,16 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. 
-use futures::prelude::*; -use libp2p_core::transport::{ListenerEvent, Transport}; -use libp2p_core::upgrade::{self, Negotiated}; -use libp2p_deflate::{DeflateConfig, DeflateOutput}; -use libp2p_tcp::{TcpConfig, TcpTransStream}; -use log::info; +use futures::{prelude::*, channel::oneshot}; +use libp2p_core::{transport::Transport, upgrade}; +use libp2p_deflate::DeflateConfig; +use libp2p_tcp::TcpConfig; use quickcheck::QuickCheck; -use tokio::{self, io}; #[test] fn deflate() { - let _ = env_logger::try_init(); - fn prop(message: Vec) -> bool { - let client = TcpConfig::new().and_then(|c, e| upgrade::apply(c, DeflateConfig {}, e)); - let server = client.clone(); - run(server, client, message); + run(message); true } @@ -43,56 +36,40 @@ fn deflate() { .quickcheck(prop as fn(Vec) -> bool) } -type Output = DeflateOutput>; - -fn run(server_transport: T, client_transport: T, message1: Vec) -where - T: Transport, - T::Dial: Send + 'static, - T::Listener: Send + 'static, - T::ListenerUpgrade: Send + 'static, -{ - let message2 = message1.clone(); - - let mut server = server_transport - .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) - .unwrap(); - let server_address = server - .by_ref() - .wait() - .next() - .expect("some event") - .expect("no error") - .into_new_address() - .expect("listen address"); - let server = server - .take(1) - .filter_map(ListenerEvent::into_upgrade) - .and_then(|(client, _)| client) - .map_err(|e| panic!("server error: {}", e)) - .and_then(|client| { - info!("server: reading message"); - io::read_to_end(client, Vec::new()) - }) - .for_each(move |(_, msg)| { - info!("server: read message: {:?}", msg); - assert_eq!(msg, message1); - Ok(()) - }); - - let client = client_transport - .dial(server_address.clone()) - .unwrap() - .map_err(|e| panic!("client error: {}", e)) - .and_then(move |server| { - io::write_all(server, message2).and_then(|(client, _)| io::shutdown(client)) - }) - .map(|_| ()); - - let future = client - .join(server) - .map_err(|e| panic!("{:?}", e)) - .map(|_| ()); - - tokio::run(future) +#[test] +fn lot_of_data() { + run((0..16*1024*1024).map(|_| rand::random::()).collect()); +} + +fn run(message1: Vec) { + let transport1 = TcpConfig::new().and_then(|c, e| upgrade::apply(c, DeflateConfig::default(), e)); + let transport2 = transport1.clone(); + let message2 = message1.clone(); + let (l_a_tx, l_a_rx) = oneshot::channel(); + + async_std::task::spawn(async move { + let mut server = transport1.listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()).unwrap(); + let server_address = server.next().await.unwrap().unwrap().into_new_address().unwrap(); + l_a_tx.send(server_address).unwrap(); + + let mut connec = server.next().await.unwrap().unwrap().into_upgrade().unwrap().0.await.unwrap(); + + let mut buf = vec![0; message2.len()]; + connec.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf[..], &message2[..]); + + connec.write_all(&message2).await.unwrap(); + connec.close().await.unwrap(); + }); + + futures::executor::block_on(async move { + let listen_addr = l_a_rx.await.unwrap(); + let mut connec = transport2.dial(listen_addr).unwrap().await.unwrap(); + connec.write_all(&message1).await.unwrap(); + connec.close().await.unwrap(); + + let mut buf = Vec::new(); + connec.read_to_end(&mut buf).await.unwrap(); + assert_eq!(&buf[..], &message1[..]); + }); } diff --git a/protocols/floodsub/Cargo.toml b/protocols/floodsub/Cargo.toml index 87ba4ab0..1ca88bd0 100644 --- a/protocols/floodsub/Cargo.toml +++ b/protocols/floodsub/Cargo.toml @@ -14,10 +14,9 @@ bs58 = "0.2.0" 
bytes = "0.4" cuckoofilter = "0.3.2" fnv = "1.0" -futures = "0.1" +futures-preview = "0.3.0-alpha.17" libp2p-core = { version = "0.12.0", path = "../../core" } libp2p-swarm = { version = "0.2.0", path = "../../swarm" } protobuf = "2.3" rand = "0.6" smallvec = "0.6.5" -tokio-io = "0.1" diff --git a/protocols/floodsub/src/layer.rs b/protocols/floodsub/src/layer.rs index ba46dfdf..3d7a0c0e 100644 --- a/protocols/floodsub/src/layer.rs +++ b/protocols/floodsub/src/layer.rs @@ -35,7 +35,7 @@ use rand; use smallvec::SmallVec; use std::{collections::VecDeque, iter, marker::PhantomData}; use std::collections::hash_map::{DefaultHasher, HashMap}; -use tokio_io::{AsyncRead, AsyncWrite}; +use std::task::{Context, Poll}; /// Network behaviour that automatically identifies nodes periodically, and returns information /// about them. @@ -230,7 +230,7 @@ impl Floodsub { impl NetworkBehaviour for Floodsub where - TSubstream: AsyncRead + AsyncWrite, + TSubstream: AsyncRead + AsyncWrite + Unpin, { type ProtocolsHandler = OneShotHandler; type OutEvent = FloodsubEvent; @@ -359,18 +359,19 @@ where fn poll( &mut self, + _: &mut Context, _: &mut impl PollParameters, - ) -> Async< + ) -> Poll< NetworkBehaviourAction< ::InEvent, Self::OutEvent, >, > { if let Some(event) = self.events.pop_front() { - return Async::Ready(event); + return Poll::Ready(event); } - Async::NotReady + Poll::Pending } } diff --git a/protocols/floodsub/src/protocol.rs b/protocols/floodsub/src/protocol.rs index e6951321..882c86d1 100644 --- a/protocols/floodsub/src/protocol.rs +++ b/protocols/floodsub/src/protocol.rs @@ -20,10 +20,10 @@ use crate::rpc_proto; use crate::topic::TopicHash; +use futures::prelude::*; use libp2p_core::{InboundUpgrade, OutboundUpgrade, UpgradeInfo, PeerId, upgrade}; use protobuf::{ProtobufError, Message as ProtobufMessage}; use std::{error, fmt, io, iter}; -use tokio_io::{AsyncRead, AsyncWrite}; /// Implementation of `ConnectionUpgrade` for the floodsub protocol. 
#[derive(Debug, Clone, Default)] @@ -49,7 +49,7 @@ impl UpgradeInfo for FloodsubConfig { impl InboundUpgrade for FloodsubConfig where - TSocket: AsyncRead + AsyncWrite, + TSocket: AsyncRead + AsyncWrite + Unpin, { type Output = FloodsubRpc; type Error = FloodsubDecodeError; @@ -164,7 +164,7 @@ impl UpgradeInfo for FloodsubRpc { impl OutboundUpgrade for FloodsubRpc where - TSocket: AsyncWrite + AsyncRead, + TSocket: AsyncWrite + AsyncRead + Unpin, { type Output = (); type Error = io::Error; diff --git a/protocols/identify/Cargo.toml b/protocols/identify/Cargo.toml index 8292c4a1..21f628ed 100644 --- a/protocols/identify/Cargo.toml +++ b/protocols/identify/Cargo.toml @@ -11,17 +11,16 @@ categories = ["network-programming", "asynchronous"] [dependencies] bytes = "0.4" -futures = "0.1" +futures_codec = "0.2" +futures-preview = "0.3.0-alpha.17" libp2p-core = { version = "0.12.0", path = "../../core" } libp2p-swarm = { version = "0.2.0", path = "../../swarm" } log = "0.4.1" multiaddr = { package = "parity-multiaddr", version = "0.5.0", path = "../../misc/multiaddr" } protobuf = "2.3" smallvec = "0.6" -tokio-codec = "0.1" -tokio-io = "0.1.0" -wasm-timer = "0.1" -unsigned-varint = { version = "0.2.1", features = ["codec"] } +wasm-timer = "0.2" +unsigned-varint = { git = "https://github.com/tomaka/unsigned-varint", branch = "futures-codec", features = ["codec"] } void = "1.0" [dev-dependencies] @@ -29,4 +28,3 @@ libp2p-mplex = { version = "0.12.0", path = "../../muxers/mplex" } libp2p-secio = { version = "0.12.0", path = "../../protocols/secio" } libp2p-tcp = { version = "0.12.0", path = "../../transports/tcp" } rand = "0.6" -tokio = "0.1" diff --git a/protocols/identify/src/handler.rs b/protocols/identify/src/handler.rs index 8e984bc7..90eb056d 100644 --- a/protocols/identify/src/handler.rs +++ b/protocols/identify/src/handler.rs @@ -23,6 +23,7 @@ use futures::prelude::*; use libp2p_core::upgrade::{ InboundUpgrade, OutboundUpgrade, + ReadOneError, Negotiated }; use libp2p_swarm::{ @@ -33,9 +34,8 @@ use libp2p_swarm::{ ProtocolsHandlerUpgrErr }; use smallvec::SmallVec; -use std::{io, marker::PhantomData, time::Duration}; -use tokio_io::{AsyncRead, AsyncWrite}; -use wasm_timer::{Delay, Instant}; +use std::{marker::PhantomData, pin::Pin, task::Context, task::Poll, time::Duration}; +use wasm_timer::Delay; use void::Void; /// Delay between the moment we connect and the first time we identify. @@ -75,7 +75,7 @@ pub enum IdentifyHandlerEvent { /// We received a request for identification. Identify(ReplySubstream>), /// Failed to identify the remote. 
- IdentificationError(ProtocolsHandlerUpgrErr), + IdentificationError(ProtocolsHandlerUpgrErr), } impl IdentifyHandler { @@ -84,7 +84,7 @@ impl IdentifyHandler { IdentifyHandler { config: IdentifyProtocolConfig, events: SmallVec::new(), - next_id: Delay::new(Instant::now() + DELAY_TO_FIRST_ID), + next_id: Delay::new(DELAY_TO_FIRST_ID), keep_alive: KeepAlive::Yes, marker: PhantomData, } @@ -93,11 +93,11 @@ impl IdentifyHandler { impl ProtocolsHandler for IdentifyHandler where - TSubstream: AsyncRead + AsyncWrite, + TSubstream: AsyncRead + AsyncWrite + Unpin + 'static, { type InEvent = Void; type OutEvent = IdentifyHandlerEvent; - type Error = wasm_timer::Error; + type Error = ReadOneError; type Substream = TSubstream; type InboundProtocol = IdentifyProtocolConfig; type OutboundProtocol = IdentifyProtocolConfig; @@ -134,38 +134,39 @@ where ) { self.events.push(IdentifyHandlerEvent::IdentificationError(err)); self.keep_alive = KeepAlive::No; - self.next_id.reset(Instant::now() + TRY_AGAIN_ON_ERR); + self.next_id.reset(TRY_AGAIN_ON_ERR); } fn connection_keep_alive(&self) -> KeepAlive { self.keep_alive } - fn poll(&mut self) -> Poll< + fn poll(&mut self, cx: &mut Context) -> Poll< ProtocolsHandlerEvent< Self::OutboundProtocol, Self::OutboundOpenInfo, IdentifyHandlerEvent, + Self::Error, >, - Self::Error, > { if !self.events.is_empty() { - return Ok(Async::Ready(ProtocolsHandlerEvent::Custom( + return Poll::Ready(ProtocolsHandlerEvent::Custom( self.events.remove(0), - ))); + )); } // Poll the future that fires when we need to identify the node again. - match self.next_id.poll()? { - Async::NotReady => Ok(Async::NotReady), - Async::Ready(()) => { - self.next_id.reset(Instant::now() + DELAY_TO_NEXT_ID); + match Future::poll(Pin::new(&mut self.next_id), cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(())) => { + self.next_id.reset(DELAY_TO_NEXT_ID); let ev = ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol: SubstreamProtocol::new(self.config.clone()), info: (), }; - Ok(Async::Ready(ev)) + Poll::Ready(ev) } + Poll::Ready(Err(err)) => Poll::Ready(ProtocolsHandlerEvent::Close(err.into())) } } } diff --git a/protocols/identify/src/identify.rs b/protocols/identify/src/identify.rs index 7c8b68e4..c28746c8 100644 --- a/protocols/identify/src/identify.rs +++ b/protocols/identify/src/identify.rs @@ -19,14 +19,14 @@ // DEALINGS IN THE SOFTWARE. use crate::handler::{IdentifyHandler, IdentifyHandlerEvent}; -use crate::protocol::{IdentifyInfo, ReplySubstream, ReplyFuture}; +use crate::protocol::{IdentifyInfo, ReplySubstream}; use futures::prelude::*; use libp2p_core::{ ConnectedPoint, Multiaddr, PeerId, PublicKey, - upgrade::{Negotiated, UpgradeError} + upgrade::{Negotiated, ReadOneError, UpgradeError} }; use libp2p_swarm::{ NetworkBehaviour, @@ -35,8 +35,7 @@ use libp2p_swarm::{ ProtocolsHandler, ProtocolsHandlerUpgrErr }; -use std::{collections::HashMap, collections::VecDeque, io}; -use tokio_io::{AsyncRead, AsyncWrite}; +use std::{collections::HashMap, collections::VecDeque, io, pin::Pin, task::Context, task::Poll}; use void::Void; /// Network behaviour that automatically identifies nodes periodically, returns information @@ -67,7 +66,7 @@ enum Reply { /// The reply is being sent. 
Sending { peer: PeerId, - io: ReplyFuture> + io: Pin> + Send>>, } } @@ -87,7 +86,7 @@ impl Identify { impl NetworkBehaviour for Identify where - TSubstream: AsyncRead + AsyncWrite, + TSubstream: AsyncRead + AsyncWrite + Unpin + Send + 'static, { type ProtocolsHandler = IdentifyHandler; type OutEvent = IdentifyEvent; @@ -154,15 +153,16 @@ where fn poll( &mut self, + cx: &mut Context, params: &mut impl PollParameters, - ) -> Async< + ) -> Poll< NetworkBehaviourAction< ::InEvent, Self::OutEvent, >, > { if let Some(event) = self.events.pop_front() { - return Async::Ready(event); + return Poll::Ready(event); } if let Some(r) = self.pending_replies.pop_front() { @@ -189,17 +189,17 @@ where listen_addrs: listen_addrs.clone(), protocols: protocols.clone(), }; - let io = io.send(info, &observed); + let io = Box::pin(io.send(info, &observed)); reply = Some(Reply::Sending { peer, io }); } Some(Reply::Sending { peer, mut io }) => { sending += 1; - match io.poll() { - Ok(Async::Ready(())) => { + match Future::poll(Pin::new(&mut io), cx) { + Poll::Ready(Ok(())) => { let event = IdentifyEvent::Sent { peer_id: peer }; - return Async::Ready(NetworkBehaviourAction::GenerateEvent(event)); + return Poll::Ready(NetworkBehaviourAction::GenerateEvent(event)); }, - Ok(Async::NotReady) => { + Poll::Pending => { self.pending_replies.push_back(Reply::Sending { peer, io }); if sending == to_send { // All remaining futures are NotReady @@ -208,12 +208,12 @@ where reply = self.pending_replies.pop_front(); } } - Err(err) => { + Poll::Ready(Err(err)) => { let event = IdentifyEvent::Error { peer_id: peer, - error: ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(err)) + error: ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(err.into())) }; - return Async::Ready(NetworkBehaviourAction::GenerateEvent(event)); + return Poll::Ready(NetworkBehaviourAction::GenerateEvent(event)); }, } } @@ -222,7 +222,7 @@ where } } - Async::NotReady + Poll::Pending } } @@ -248,7 +248,7 @@ pub enum IdentifyEvent { /// The peer with whom the error originated. peer_id: PeerId, /// The error that occurred. - error: ProtocolsHandlerUpgrErr, + error: ProtocolsHandlerUpgrErr, }, } @@ -326,7 +326,7 @@ mod tests { assert_eq!(info.agent_version, "d"); assert!(!info.protocols.is_empty()); assert!(info.listen_addrs.is_empty()); - return Ok(Async::Ready(())) + return Ok(Poll::Ready(())) }, Async::Ready(Some(IdentifyEvent::Sent { .. })) => (), Async::Ready(e) => panic!("{:?}", e), @@ -340,7 +340,7 @@ mod tests { assert_eq!(info.agent_version, "b"); assert!(!info.protocols.is_empty()); assert_eq!(info.listen_addrs.len(), 1); - return Ok(Async::Ready(())) + return Ok(Poll::Ready(())) }, Async::Ready(Some(IdentifyEvent::Sent { .. })) => (), Async::Ready(e) => panic!("{:?}", e), @@ -348,7 +348,7 @@ mod tests { } } - Ok(Async::NotReady) + Ok(Poll::Pending) })) .unwrap(); } diff --git a/protocols/identify/src/protocol.rs b/protocols/identify/src/protocol.rs index 8b197414..4e27effe 100644 --- a/protocols/identify/src/protocol.rs +++ b/protocols/identify/src/protocol.rs @@ -18,25 +18,19 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. 
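The identify handler and behaviour above now take an explicit `&mut Context` in `poll` and keep the in-flight reply as a pinned, boxed future that is polled on each wakeup. A minimal sketch of that storage pattern, assuming only the futures-preview 0.3 API (the alias and helper below are illustrative, not part of this patch):

    use futures::prelude::*;
    use std::{io, pin::Pin, task::{Context, Poll}};

    // A reply in flight, boxed so it can be stored in a struct field,
    // much like `Reply::Sending { io, .. }` above.
    type SendFuture = Pin<Box<dyn Future<Output = io::Result<()>> + Send>>;

    fn poll_send(fut: &mut SendFuture, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        // `as_mut()` re-borrows the pinned box so the inner future can be polled in place.
        fut.as_mut().poll(cx)
    }

On `Poll::Pending` the caller simply keeps the future around and tries again on the next wakeup, which is what the behaviour does with its `pending_replies` queue.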
-use bytes::BytesMut; use crate::structs_proto; -use futures::{future::{self, FutureResult}, Async, AsyncSink, Future, Poll, Sink, Stream}; -use futures::try_ready; +use futures::prelude::*; use libp2p_core::{ Multiaddr, PublicKey, - upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo, Negotiated} + upgrade::{self, InboundUpgrade, OutboundUpgrade, UpgradeInfo, Negotiated} }; use log::{debug, trace}; use protobuf::Message as ProtobufMessage; use protobuf::parse_from_bytes as protobuf_parse_from_bytes; use protobuf::RepeatedField; use std::convert::TryFrom; -use std::io::{Error as IoError, ErrorKind as IoErrorKind}; -use std::{fmt, iter}; -use tokio_codec::Framed; -use tokio_io::{AsyncRead, AsyncWrite}; -use unsigned_varint::codec; +use std::{fmt, io, iter, pin::Pin}; /// Configuration for an upgrade to the `Identify` protocol. #[derive(Debug, Clone)] @@ -54,7 +48,7 @@ pub struct RemoteInfo { /// The substream on which a reply is expected to be sent. pub struct ReplySubstream { - inner: Framed>>, + inner: T, } impl fmt::Debug for ReplySubstream { @@ -65,13 +59,15 @@ impl fmt::Debug for ReplySubstream { impl ReplySubstream where - T: AsyncWrite + T: AsyncWrite + Unpin { /// Sends back the requested information on the substream. /// /// Consumes the substream, returning a `ReplyFuture` that resolves /// when the reply has been sent on the underlying connection. - pub fn send(self, info: IdentifyInfo, observed_addr: &Multiaddr) -> ReplyFuture { + pub fn send(mut self, info: IdentifyInfo, observed_addr: &Multiaddr) + -> impl Future> + { debug!("Sending identify info to client"); trace!("Sending: {:?}", info); @@ -90,50 +86,15 @@ where message.set_observedAddr(observed_addr.to_vec()); message.set_protocols(RepeatedField::from_vec(info.protocols)); - let bytes = message - .write_to_bytes() - .expect("writing protobuf failed; should never happen"); - - ReplyFuture { - inner: self.inner, - item: Some(bytes), + async move { + let bytes = message + .write_to_bytes() + .expect("writing protobuf failed; should never happen"); + upgrade::write_one(&mut self.inner, &bytes).await } } } -/// Future returned by `IdentifySender::send()`. Must be processed to the end in order to send -/// the information to the remote. -// Note: we don't use a `futures::sink::Sink` because it requires `T` to implement `Sink`, which -// means that we would require `T: AsyncWrite` in this struct definition. This requirement -// would then propagate everywhere. -#[must_use = "futures do nothing unless polled"] -pub struct ReplyFuture { - /// The Sink where to send the data. - inner: Framed>>, - /// Bytes to send, or `None` if we've already sent them. - item: Option>, -} - -impl Future for ReplyFuture -where T: AsyncWrite -{ - type Item = (); - type Error = IoError; - - fn poll(&mut self) -> Poll { - if let Some(item) = self.item.take() { - if let AsyncSink::NotReady(item) = self.inner.start_send(item)? { - self.item = Some(item); - return Ok(Async::NotReady); - } - } - - // A call to `close()` implies flushing. - try_ready!(self.inner.close()); - Ok(Async::Ready(())) - } -} - /// Information of a peer sent in `Identify` protocol responses. 
#[derive(Debug, Clone)] pub struct IdentifyInfo { @@ -162,93 +123,60 @@ impl UpgradeInfo for IdentifyProtocolConfig { impl InboundUpgrade for IdentifyProtocolConfig where - C: AsyncRead + AsyncWrite, + C: AsyncRead + AsyncWrite + Unpin, { type Output = ReplySubstream>; - type Error = IoError; - type Future = FutureResult; + type Error = io::Error; + type Future = future::Ready>; fn upgrade_inbound(self, socket: Negotiated, _: Self::Info) -> Self::Future { trace!("Upgrading inbound connection"); - let inner = Framed::new(socket, codec::UviBytes::default()); - future::ok(ReplySubstream { inner }) + future::ok(ReplySubstream { inner: socket }) } } impl OutboundUpgrade for IdentifyProtocolConfig where - C: AsyncRead + AsyncWrite, + C: AsyncRead + AsyncWrite + Unpin + 'static, { type Output = RemoteInfo; - type Error = IoError; - type Future = IdentifyOutboundFuture>; + type Error = upgrade::ReadOneError; + type Future = Pin>>>; - fn upgrade_outbound(self, socket: Negotiated, _: Self::Info) -> Self::Future { - IdentifyOutboundFuture { - inner: Framed::new(socket, codec::UviBytes::::default()), - shutdown: false, - } - } -} + fn upgrade_outbound(self, mut socket: Negotiated, _: Self::Info) -> Self::Future { + Box::pin(async move { + socket.close().await?; + let msg = upgrade::read_one(&mut socket, 4096).await?; + let (info, observed_addr) = match parse_proto_msg(msg) { + Ok(v) => v, + Err(err) => { + debug!("Failed to parse protobuf message; error = {:?}", err); + return Err(err.into()) + } + }; -/// Future returned by `OutboundUpgrade::upgrade_outbound`. -pub struct IdentifyOutboundFuture { - inner: Framed>, - /// If true, we have finished shutting down the writing part of `inner`. - shutdown: bool, -} + trace!("Remote observes us as {:?}", observed_addr); + trace!("Information received: {:?}", info); -impl Future for IdentifyOutboundFuture -where T: AsyncRead + AsyncWrite, -{ - type Item = RemoteInfo; - type Error = IoError; - - fn poll(&mut self) -> Poll { - if !self.shutdown { - try_ready!(self.inner.close()); - self.shutdown = true; - } - - let msg = match try_ready!(self.inner.poll()) { - Some(i) => i, - None => { - debug!("Identify protocol stream closed before receiving info"); - return Err(IoErrorKind::InvalidData.into()); - } - }; - - debug!("Received identify message"); - - let (info, observed_addr) = match parse_proto_msg(msg) { - Ok(v) => v, - Err(err) => { - debug!("Failed to parse protobuf message; error = {:?}", err); - return Err(err) - } - }; - - trace!("Remote observes us as {:?}", observed_addr); - trace!("Information received: {:?}", info); - - Ok(Async::Ready(RemoteInfo { - info, - observed_addr: observed_addr.clone(), - _priv: () - })) + Ok(RemoteInfo { + info, + observed_addr: observed_addr.clone(), + _priv: () + }) + }) } } // Turns a protobuf message into an `IdentifyInfo` and an observed address. If something bad -// happens, turn it into an `IoError`. -fn parse_proto_msg(msg: BytesMut) -> Result<(IdentifyInfo, Multiaddr), IoError> { - match protobuf_parse_from_bytes::(&msg) { +// happens, turn it into an `io::Error`. +fn parse_proto_msg(msg: impl AsRef<[u8]>) -> Result<(IdentifyInfo, Multiaddr), io::Error> { + match protobuf_parse_from_bytes::(msg.as_ref()) { Ok(mut msg) => { // Turn a `Vec` into a `Multiaddr`. If something bad happens, turn it into - // an `IoError`. - fn bytes_to_multiaddr(bytes: Vec) -> Result { + // an `io::Error`. 
+ fn bytes_to_multiaddr(bytes: Vec) -> Result { Multiaddr::try_from(bytes) - .map_err(|err| IoError::new(IoErrorKind::InvalidData, err)) + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err)) } let listen_addrs = { @@ -260,7 +188,7 @@ fn parse_proto_msg(msg: BytesMut) -> Result<(IdentifyInfo, Multiaddr), IoError> }; let public_key = PublicKey::from_protobuf_encoding(msg.get_publicKey()) - .map_err(|e| IoError::new(IoErrorKind::InvalidData, e))?; + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; let observed_addr = bytes_to_multiaddr(msg.take_observedAddr())?; let info = IdentifyInfo { @@ -274,7 +202,7 @@ fn parse_proto_msg(msg: BytesMut) -> Result<(IdentifyInfo, Multiaddr), IoError> Ok((info, observed_addr)) } - Err(err) => Err(IoError::new(IoErrorKind::InvalidData, err)), + Err(err) => Err(io::Error::new(io::ErrorKind::InvalidData, err)), } } diff --git a/protocols/kad/Cargo.toml b/protocols/kad/Cargo.toml index 6be0b952..4c101298 100644 --- a/protocols/kad/Cargo.toml +++ b/protocols/kad/Cargo.toml @@ -28,7 +28,7 @@ tokio-codec = "0.1" tokio-io = "0.1" wasm-timer = "0.1" uint = "0.8" -unsigned-varint = { version = "0.2.1", features = ["codec"] } +unsigned-varint = { git = "https://github.com/tomaka/unsigned-varint", branch = "futures-codec", features = ["codec"] } void = "1.0" [dev-dependencies] diff --git a/protocols/kad/src/handler.rs b/protocols/kad/src/handler.rs index 5a559433..137bc704 100644 --- a/protocols/kad/src/handler.rs +++ b/protocols/kad/src/handler.rs @@ -640,7 +640,7 @@ where fn poll( &mut self, ) -> Poll< - ProtocolsHandlerEvent, + ProtocolsHandlerEvent, io::Error, > { // We remove each element from `substreams` one by one and add them back. diff --git a/protocols/noise/Cargo.toml b/protocols/noise/Cargo.toml index 189c61be..000fb508 100644 --- a/protocols/noise/Cargo.toml +++ b/protocols/noise/Cargo.toml @@ -10,7 +10,7 @@ edition = "2018" [dependencies] bytes = "0.4" curve25519-dalek = "1" -futures = "0.1" +futures-preview = "0.3.0-alpha.17" lazy_static = "1.2" libp2p-core = { version = "0.12.0", path = "../../core" } log = "0.4" diff --git a/protocols/noise/src/io/handshake.rs b/protocols/noise/src/io/handshake.rs index 93a1f206..f0dac45c 100644 --- a/protocols/noise/src/io/handshake.rs +++ b/protocols/noise/src/io/handshake.rs @@ -25,30 +25,12 @@ mod payload; use crate::error::NoiseError; use crate::protocol::{Protocol, PublicKey, KeypairIdentity}; use libp2p_core::identity; -use futures::{future, Async, Future, future::FutureResult, Poll}; -use std::{mem, io}; -use tokio_io::{io as nio, AsyncWrite, AsyncRead}; +use futures::prelude::*; +use std::{mem, io, task::Poll}; use protobuf::Message; use super::NoiseOutput; -/// A future performing a Noise handshake pattern. -pub struct Handshake( - Box as Future>::Item, - Error = as Future>::Error - > + Send> -); - -impl Future for Handshake { - type Error = NoiseError; - type Item = (RemoteIdentity, NoiseOutput); - - fn poll(&mut self) -> Poll { - self.0.poll() - } -} - /// The identity of the remote established during a handshake. pub enum RemoteIdentity { /// The remote provided no identifying information. 
@@ -131,12 +113,11 @@ where session: Result, identity: KeypairIdentity, identity_x: IdentityExchange - ) -> Handshake { - Handshake(Box::new( - State::new(io, session, identity, identity_x) - .and_then(State::send_identity) - .and_then(State::recv_identity) - .and_then(State::finish))) + ) -> Result<(RemoteIdentity, NoiseOutput), NoiseError> { + let mut state = State::new(io, session, identity, identity_x); + send_identity(&mut state).await?; + recv_identity(&mut state).await?; + state.finish().await } /// Creates an authenticated Noise handshake for the responder of a @@ -160,12 +141,11 @@ where session: Result, identity: KeypairIdentity, identity_x: IdentityExchange, - ) -> Handshake { - Handshake(Box::new( - State::new(io, session, identity, identity_x) - .and_then(State::recv_identity) - .and_then(State::send_identity) - .and_then(State::finish))) + ) -> Result<(RemoteIdentity, NoiseOutput), NoiseError> { + let mut state = State::new(io, session, identity, identity_x); + recv_identity(&mut state).await?; + send_identity(&mut state).await?; + state.finish().await } /// Creates an authenticated Noise handshake for the initiator of a @@ -191,13 +171,12 @@ where session: Result, identity: KeypairIdentity, identity_x: IdentityExchange - ) -> Handshake { - Handshake(Box::new( - State::new(io, session, identity, identity_x) - .and_then(State::send_empty) - .and_then(State::recv_identity) - .and_then(State::send_identity) - .and_then(State::finish))) + ) -> Result<(RemoteIdentity, NoiseOutput), NoiseError> { + let mut state = State::new(io, session, identity, identity_x); + send_empty(&mut state).await?; + recv_identity(&mut state).await?; + send_identity(&mut state).await?; + state.finish().await } /// Creates an authenticated Noise handshake for the responder of a @@ -218,18 +197,17 @@ where /// initiator <-{id}- responder /// initiator -{id}-> responder /// ``` - pub fn rt15_responder( + pub async fn rt15_responder( io: T, session: Result, identity: KeypairIdentity, identity_x: IdentityExchange - ) -> Handshake { - Handshake(Box::new( - State::new(io, session, identity, identity_x) - .and_then(State::recv_empty) - .and_then(State::send_identity) - .and_then(State::recv_identity) - .and_then(State::finish))) + ) -> Result<(RemoteIdentity, NoiseOutput), NoiseError> { + let mut state = State::new(io, session, identity, identity_x); + recv_empty(&mut state).await?; + send_identity(&mut state).await?; + recv_identity(&mut state).await?; + state.finish().await } } @@ -251,36 +229,6 @@ struct State { send_identity: bool, } -impl io::Read for State { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.io.read(buf) - } -} - -impl io::Write for State { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.io.write(buf) - } - fn flush(&mut self) -> io::Result<()> { - self.io.flush() - } -} - -impl AsyncRead for State { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { - self.io.prepare_uninitialized_buffer(buf) - } - fn read_buf(&mut self, buf: &mut B) -> Poll { - self.io.read_buf(buf) - } -} - -impl AsyncWrite for State { - fn shutdown(&mut self) -> Poll<(), io::Error> { - self.io.shutdown() - } -} - impl State { /// Initializes the state for a new Noise handshake, using the given local /// identity keypair and local DH static public key. The handshake messages @@ -346,30 +294,6 @@ impl State } } -impl State { - /// Creates a future that sends a Noise handshake message with an empty payload.
- fn send_empty(self) -> SendEmpty { - SendEmpty { state: SendState::Write(self) } - } - - /// Creates a future that expects to receive a Noise handshake message with an empty payload. - fn recv_empty(self) -> RecvEmpty { - RecvEmpty { state: RecvState::Read(self) } - } - - /// Creates a future that sends a Noise handshake message with a payload identifying - /// the local node to the remote. - fn send_identity(self) -> SendIdentity { - SendIdentity { state: SendIdentityState::Init(self) } - } - - /// Creates a future that expects to receive a Noise handshake message with a - /// payload identifying the remote. - fn recv_identity(self) -> RecvIdentity { - RecvIdentity { state: RecvIdentityState::Init(self) } - } -} - ////////////////////////////////////////////////////////////////////////////// // Handshake Message Futures @@ -378,34 +302,12 @@ impl State { /// A future for receiving a Noise handshake message with an empty payload. /// /// Obtained from [`Handshake::recv_empty`]. -struct RecvEmpty { - state: RecvState -} - -enum RecvState { - Read(State), - Done -} - -impl Future for RecvEmpty +async fn recv_empty(state: &mut State) -> Result<(), NoiseError> where T: AsyncRead { - type Error = NoiseError; - type Item = State; - - fn poll(&mut self) -> Poll { - match mem::replace(&mut self.state, RecvState::Done) { - RecvState::Read(mut st) => { - if !st.io.poll_read(&mut [])?.is_ready() { - self.state = RecvState::Read(st); - return Ok(Async::NotReady) - } - Ok(Async::Ready(st)) - }, - RecvState::Done => panic!("RecvEmpty polled after completion") - } - } + state.io.read(&mut []).await?; + Ok(()) } // SendEmpty ----------------------------------------------------------------- @@ -413,44 +315,13 @@ where /// A future for sending a Noise handshake message with an empty payload. /// /// Obtained from [`Handshake::send_empty`]. -struct SendEmpty { - state: SendState -} - -enum SendState { - Write(State), - Flush(State), - Done -} - -impl Future for SendEmpty +async fn send_empty(state: &mut State) -> Result<(), NoiseError> where T: AsyncWrite { - type Error = NoiseError; - type Item = State; - - fn poll(&mut self) -> Poll { - loop { - match mem::replace(&mut self.state, SendState::Done) { - SendState::Write(mut st) => { - if !st.io.poll_write(&mut [])?.is_ready() { - self.state = SendState::Write(st); - return Ok(Async::NotReady) - } - self.state = SendState::Flush(st); - }, - SendState::Flush(mut st) => { - if !st.io.poll_flush()?.is_ready() { - self.state = SendState::Flush(st); - return Ok(Async::NotReady) - } - return Ok(Async::Ready(st)) - } - SendState::Done => panic!("SendEmpty polled after completion") - } - } - } + state.io.write(&[]).await?; + state.io.flush().await?; + Ok(()) } // RecvIdentity -------------------------------------------------------------- @@ -523,71 +394,24 @@ where // SendIdentity -------------------------------------------------------------- -/// A future for sending a Noise handshake message with a payload -/// identifying the local node to the remote. +/// Send a Noise handshake message with a payload identifying the local node to the remote. /// /// Obtained from [`Handshake::send_identity`].
-struct SendIdentity { - state: SendIdentityState -} - -enum SendIdentityState { - Init(State), - WritePayloadLen(nio::WriteAll, [u8; 2]>, Vec), - WritePayload(nio::WriteAll, Vec>), - Flush(State), - Done -} - -impl Future for SendIdentity +async fn send_identity(state: &mut State) -> Result<(), NoiseError> where - T: AsyncWrite, + T: AsyncWrite { - type Error = NoiseError; - type Item = State; - - fn poll(&mut self) -> Poll { - loop { - match mem::replace(&mut self.state, SendIdentityState::Done) { - SendIdentityState::Init(st) => { - let mut pb = payload::Identity::new(); - if st.send_identity { - pb.set_pubkey(st.identity.public.clone().into_protobuf_encoding()); - } - if let Some(ref sig) = st.identity.signature { - pb.set_signature(sig.clone()); - } - let pb_bytes = pb.write_to_bytes()?; - let len = (pb_bytes.len() as u16).to_be_bytes(); - let write_len = nio::write_all(st, len); - self.state = SendIdentityState::WritePayloadLen(write_len, pb_bytes); - }, - SendIdentityState::WritePayloadLen(mut write_len, payload) => { - if let Async::Ready((st, _)) = write_len.poll()? { - self.state = SendIdentityState::WritePayload(nio::write_all(st, payload)); - } else { - self.state = SendIdentityState::WritePayloadLen(write_len, payload); - return Ok(Async::NotReady) - } - }, - SendIdentityState::WritePayload(mut write_payload) => { - if let Async::Ready((st, _)) = write_payload.poll()? { - self.state = SendIdentityState::Flush(st); - } else { - self.state = SendIdentityState::WritePayload(write_payload); - return Ok(Async::NotReady) - } - }, - SendIdentityState::Flush(mut st) => { - if !st.poll_flush()?.is_ready() { - self.state = SendIdentityState::Flush(st); - return Ok(Async::NotReady) - } - return Ok(Async::Ready(st)) - }, - SendIdentityState::Done => panic!("SendIdentity polled after completion") - } - } + let mut pb = payload::Identity::new(); + if state.send_identity { + pb.set_pubkey(state.identity.public.clone().into_protobuf_encoding()); + } + if let Some(ref sig) = state.identity.signature { + pb.set_signature(sig.clone()); + } + let pb_bytes = pb.write_to_bytes()?; + let len = (pb_bytes.len() as u16).to_be_bytes(); + state.io.write_all(&len).await?; + state.io.write_all(&pb_bytes).await?; + state.io.flush().await?; + Ok(()) } - diff --git a/protocols/noise/src/lib.rs b/protocols/noise/src/lib.rs index fc6ed25e..97346a52 100644 --- a/protocols/noise/src/lib.rs +++ b/protocols/noise/src/lib.rs @@ -57,11 +57,10 @@ mod protocol; pub use error::NoiseError; pub use io::NoiseOutput; -pub use io::handshake::{Handshake, RemoteIdentity, IdentityExchange}; +pub use io::handshake::{RemoteIdentity, IdentityExchange}; pub use protocol::{Keypair, AuthenticKeypair, KeypairIdentity, PublicKey, SecretKey}; pub use protocol::{Protocol, ProtocolParams, x25519::X25519, IX, IK, XX}; -use futures::{future::{self, FutureResult}, Future}; use libp2p_core::{identity, PeerId, UpgradeInfo, InboundUpgrade, OutboundUpgrade, Negotiated}; use tokio_io::{AsyncRead, AsyncWrite}; use zeroize::Zeroize; diff --git a/protocols/ping/Cargo.toml b/protocols/ping/Cargo.toml index cde291aa..c8899916 100644 --- a/protocols/ping/Cargo.toml +++ b/protocols/ping/Cargo.toml @@ -15,10 +15,9 @@ libp2p-core = { version = "0.12.0", path = "../../core" } libp2p-swarm = { version = "0.2.0", path = "../../swarm" } log = "0.4.1" multiaddr = { package = "parity-multiaddr", version = "0.5.0", path = "../../misc/multiaddr" } -futures = "0.1" +futures-preview = "0.3.0-alpha.17" rand = "0.6" -tokio-io = "0.1" -wasm-timer = "0.1" +wasm-timer = "0.2" void = "1.0"
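The ping crate now pulls in wasm-timer 0.2, whose `Delay` is created from and reset with a `Duration` rather than an `Instant`, and whose future output is a `Result`, as the `Delay::new(DELAY_TO_FIRST_ID)` and `Poll::Ready(Ok(()))` usage in the identify handler above already shows. A rough illustration of that API, assuming `Delay` is `Unpin` and its error type is `io::Error` (not taken from this patch):

    use std::{io, time::Duration};
    use wasm_timer::Delay;

    async fn wait_twice() -> io::Result<()> {
        let mut delay = Delay::new(Duration::from_secs(1));
        (&mut delay).await?;                  // output assumed to be Result<(), io::Error>
        delay.reset(Duration::from_secs(1));  // re-armed with a Duration, not an Instant
        (&mut delay).await
    }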
[dev-dependencies] diff --git a/protocols/ping/src/handler.rs b/protocols/ping/src/handler.rs index 0c3116bf..37e9ad17 100644 --- a/protocols/ping/src/handler.rs +++ b/protocols/ping/src/handler.rs @@ -27,10 +27,9 @@ use libp2p_swarm::{ ProtocolsHandlerUpgrErr, ProtocolsHandlerEvent }; -use std::{error::Error, io, fmt, num::NonZeroU32, time::Duration}; +use std::{error::Error, io, fmt, num::NonZeroU32, pin::Pin, task::Context, task::Poll, time::Duration}; use std::collections::VecDeque; -use tokio_io::{AsyncRead, AsyncWrite}; -use wasm_timer::{Delay, Instant}; +use wasm_timer::Delay; use void::Void; /// The configuration for outbound pings. @@ -176,7 +175,7 @@ impl PingHandler { pub fn new(config: PingConfig) -> Self { PingHandler { config, - next_ping: Delay::new(Instant::now()), + next_ping: Delay::new(Duration::new(0, 0)), pending_results: VecDeque::with_capacity(2), failures: 0, _marker: std::marker::PhantomData @@ -186,7 +185,7 @@ impl PingHandler { impl ProtocolsHandler for PingHandler where - TSubstream: AsyncRead + AsyncWrite, + TSubstream: AsyncRead + AsyncWrite + Unpin + 'static, { type InEvent = Void; type OutEvent = PingResult; @@ -228,36 +227,36 @@ where } } - fn poll(&mut self) -> Poll, Self::Error> { + fn poll(&mut self, cx: &mut Context) -> Poll> { if let Some(result) = self.pending_results.pop_back() { if let Ok(PingSuccess::Ping { .. }) = result { - let next_ping = Instant::now() + self.config.interval; self.failures = 0; - self.next_ping.reset(next_ping); + self.next_ping.reset(self.config.interval); } if let Err(e) = result { self.failures += 1; if self.failures >= self.config.max_failures.get() { - return Err(e) + return Poll::Ready(ProtocolsHandlerEvent::Close(e)) } else { - return Ok(Async::Ready(ProtocolsHandlerEvent::Custom(Err(e)))) + return Poll::Ready(ProtocolsHandlerEvent::Custom(Err(e))) } } - return Ok(Async::Ready(ProtocolsHandlerEvent::Custom(result))) + return Poll::Ready(ProtocolsHandlerEvent::Custom(result)) } - match self.next_ping.poll() { - Ok(Async::Ready(())) => { - self.next_ping.reset(Instant::now() + self.config.timeout); + match Future::poll(Pin::new(&mut self.next_ping), cx) { + Poll::Ready(Ok(())) => { + self.next_ping.reset(self.config.timeout); let protocol = SubstreamProtocol::new(protocol::Ping) .with_timeout(self.config.timeout); - Ok(Async::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { + Poll::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol, info: (), - })) + }) }, - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(e) => Err(PingFailure::Other { error: Box::new(e) }) + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => + Poll::Ready(ProtocolsHandlerEvent::Close(PingFailure::Other { error: Box::new(e) })) } } } @@ -285,7 +284,7 @@ mod tests { ProtocolsHandlerEvent, PingFailure > { - Runtime::new().unwrap().block_on(future::poll_fn(|| h.poll() )) + futures::executor::block_on(future::poll_fn(|cx| h.poll(cx) )) } #[test] diff --git a/protocols/ping/src/lib.rs b/protocols/ping/src/lib.rs index 1353ffa1..38d0df4f 100644 --- a/protocols/ping/src/lib.rs +++ b/protocols/ping/src/lib.rs @@ -50,9 +50,7 @@ use handler::PingHandler; use futures::prelude::*; use libp2p_core::{ConnectedPoint, Multiaddr, PeerId}; use libp2p_swarm::{NetworkBehaviour, NetworkBehaviourAction, PollParameters}; -use std::collections::VecDeque; -use std::marker::PhantomData; -use tokio_io::{AsyncRead, AsyncWrite}; +use std::{collections::VecDeque, marker::PhantomData, task::Context, task::Poll}; use void::Void; /// `Ping` is a
[`NetworkBehaviour`] that responds to inbound pings and @@ -95,7 +93,7 @@ impl Default for Ping { impl NetworkBehaviour for Ping where - TSubstream: AsyncRead + AsyncWrite, + TSubstream: AsyncRead + AsyncWrite + Unpin + 'static, { type ProtocolsHandler = PingHandler; type OutEvent = PingEvent; @@ -116,12 +114,13 @@ where self.events.push_front(PingEvent { peer, result }) } - fn poll(&mut self, _: &mut impl PollParameters) -> Async> + fn poll(&mut self, _: &mut Context, _: &mut impl PollParameters) + -> Poll> { if let Some(e) = self.events.pop_back() { - Async::Ready(NetworkBehaviourAction::GenerateEvent(e)) + Poll::Ready(NetworkBehaviourAction::GenerateEvent(e)) } else { - Async::NotReady + Poll::Pending } } } diff --git a/protocols/ping/src/protocol.rs b/protocols/ping/src/protocol.rs index 926aad03..8a3e7d53 100644 --- a/protocols/ping/src/protocol.rs +++ b/protocols/ping/src/protocol.rs @@ -18,12 +18,11 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use futures::{prelude::*, future, try_ready}; -use libp2p_core::{InboundUpgrade, OutboundUpgrade, UpgradeInfo, upgrade::Negotiated}; +use futures::prelude::*; +use libp2p_core::{InboundUpgrade, OutboundUpgrade, UpgradeInfo, Negotiated}; use log::debug; use rand::{distributions, prelude::*}; -use std::{io, iter, time::Duration}; -use tokio_io::{io as nio, AsyncRead, AsyncWrite}; +use std::{io, iter, pin::Pin, time::Duration}; use wasm_timer::Instant; /// Represents a prototype for an upgrade to handle the ping protocol. @@ -54,126 +53,50 @@ impl UpgradeInfo for Ping { } } -type RecvPing = nio::ReadExact, [u8; 32]>; -type SendPong = nio::WriteAll, [u8; 32]>; -type Flush = nio::Flush>; -type Shutdown = nio::Shutdown>; - impl InboundUpgrade for Ping where - TSocket: AsyncRead + AsyncWrite, + TSocket: AsyncRead + AsyncWrite + Unpin + 'static, { type Output = (); type Error = io::Error; - type Future = future::Map< - future::AndThen< - future::AndThen< - future::AndThen< - RecvPing, - SendPong, fn((Negotiated, [u8; 32])) -> SendPong>, - Flush, fn((Negotiated, [u8; 32])) -> Flush>, - Shutdown, fn(Negotiated) -> Shutdown>, - fn(Negotiated) -> ()>; + type Future = Pin>>>; - #[inline] - fn upgrade_inbound(self, socket: Negotiated, _: Self::Info) -> Self::Future { - nio::read_exact(socket, [0; 32]) - .and_then:: _, _>(|(sock, buf)| nio::write_all(sock, buf)) - .and_then:: _, _>(|(sock, _)| nio::flush(sock)) - .and_then:: _, _>(|sock| nio::shutdown(sock)) - .map(|_| ()) + fn upgrade_inbound(self, mut socket: Negotiated, _: Self::Info) -> Self::Future { + Box::pin(async move { + let mut payload = [0u8; 32]; + socket.read_exact(&mut payload).await?; + socket.write_all(&payload).await?; + socket.close().await?; + Ok(()) + }) } } impl OutboundUpgrade for Ping where - TSocket: AsyncRead + AsyncWrite, + TSocket: AsyncRead + AsyncWrite + Unpin + 'static, { type Output = Duration; type Error = io::Error; - type Future = PingDialer>; + type Future = Pin>>>; - #[inline] - fn upgrade_outbound(self, socket: Negotiated, _: Self::Info) -> Self::Future { + fn upgrade_outbound(self, mut socket: Negotiated, _: Self::Info) -> Self::Future { let payload: [u8; 32] = thread_rng().sample(distributions::Standard); debug!("Preparing ping payload {:?}", payload); - PingDialer { - state: PingDialerState::Write { - inner: nio::write_all(socket, payload), - }, - } - } -} + Box::pin(async move { + socket.write_all(&payload).await?; + socket.close().await?; + let started = Instant::now(); -/// A `PingDialer` is a future that 
sends a ping and expects to receive a pong. -pub struct PingDialer { - state: PingDialerState -} - -enum PingDialerState { - Write { - inner: nio::WriteAll, - }, - Flush { - inner: nio::Flush, - payload: [u8; 32], - }, - Read { - inner: nio::ReadExact, - payload: [u8; 32], - started: Instant, - }, - Shutdown { - inner: nio::Shutdown, - rtt: Duration, - }, -} - -impl Future for PingDialer -where - TSocket: AsyncRead + AsyncWrite, -{ - type Item = Duration; - type Error = io::Error; - - fn poll(&mut self) -> Poll { - loop { - self.state = match self.state { - PingDialerState::Write { ref mut inner } => { - let (socket, payload) = try_ready!(inner.poll()); - PingDialerState::Flush { - inner: nio::flush(socket), - payload, - } - }, - PingDialerState::Flush { ref mut inner, payload } => { - let socket = try_ready!(inner.poll()); - let started = Instant::now(); - PingDialerState::Read { - inner: nio::read_exact(socket, [0; 32]), - payload, - started, - } - }, - PingDialerState::Read { ref mut inner, payload, started } => { - let (socket, payload_received) = try_ready!(inner.poll()); - let rtt = started.elapsed(); - if payload_received != payload { - return Err(io::Error::new( - io::ErrorKind::InvalidData, "Ping payload mismatch")); - } - PingDialerState::Shutdown { - inner: nio::shutdown(socket), - rtt, - } - }, - PingDialerState::Shutdown { ref mut inner, rtt } => { - try_ready!(inner.poll()); - return Ok(Async::Ready(rtt)); - }, + let mut recv_payload = [0u8; 32]; + socket.read_exact(&mut recv_payload).await?; + if recv_payload == payload { + Ok(started.elapsed()) + } else { + Err(io::Error::new(io::ErrorKind::InvalidData, "Ping payload mismatch")) } - } + }) } } @@ -199,7 +122,7 @@ mod tests { let mut listener = MemoryTransport.listen_on(mem_addr).unwrap(); let listener_addr = - if let Ok(Async::Ready(Some(ListenerEvent::NewAddress(a)))) = listener.poll() { + if let Ok(Poll::Ready(Some(ListenerEvent::NewAddress(a)))) = listener.poll() { a } else { panic!("MemoryTransport not listening on an address!"); diff --git a/protocols/ping/tests/ping.rs b/protocols/ping/tests/ping.rs index 6d6b98c2..dbb73f15 100644 --- a/protocols/ping/tests/ping.rs +++ b/protocols/ping/tests/ping.rs @@ -98,7 +98,7 @@ fn ping() { }); let result = peer1.select(peer2).map_err(|e| panic!(e)); - let ((p1, p2, rtt), _) = Runtime::new().unwrap().block_on(result).unwrap(); + let ((p1, p2, rtt), _) = futures::executor::block_on(result).unwrap(); assert!(p1 == peer1_id && p2 == peer2_id || p1 == peer2_id && p2 == peer1_id); assert!(rtt < Duration::from_millis(50)); } diff --git a/protocols/plaintext/Cargo.toml b/protocols/plaintext/Cargo.toml index 5b04674b..912c5a4c 100644 --- a/protocols/plaintext/Cargo.toml +++ b/protocols/plaintext/Cargo.toml @@ -10,7 +10,7 @@ keywords = ["peer-to-peer", "libp2p", "networking"] categories = ["network-programming", "asynchronous"] [dependencies] -futures = "0.1" +futures-preview = "0.3.0-alpha.17" libp2p-core = { version = "0.12.0", path = "../../core" } void = "1" diff --git a/protocols/plaintext/src/lib.rs b/protocols/plaintext/src/lib.rs index c8c6aafb..c4cda8e6 100644 --- a/protocols/plaintext/src/lib.rs +++ b/protocols/plaintext/src/lib.rs @@ -18,7 +18,7 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. 
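The ping upgrades above drop the hand-written future state machines in favour of `async move` blocks boxed into the associated `Future` type, since the anonymous type of an `async` block cannot be named in a trait impl. A self-contained sketch of that shape, using an invented trait rather than the real libp2p upgrade traits:

    use futures::prelude::*;
    use std::{io, pin::Pin};

    // Illustrative stand-in for an upgrade trait with an associated future type.
    trait Echo<S> {
        type Future: Future<Output = io::Result<()>>;
        fn echo(self, socket: S) -> Self::Future;
    }

    struct EchoOnce;

    impl<S> Echo<S> for EchoOnce
    where
        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    {
        // Boxing erases the unnameable `async move` type behind `dyn Future`.
        type Future = Pin<Box<dyn Future<Output = io::Result<()>> + Send>>;

        fn echo(self, mut socket: S) -> Self::Future {
            Box::pin(async move {
                let mut buf = [0u8; 32];
                socket.read_exact(&mut buf).await?;
                socket.write_all(&buf).await?;
                socket.close().await?;
                Ok(())
            })
        }
    }

The cost is one heap allocation per upgrade; naming the future directly would need `impl Trait` in an associated type, which is not available on stable at this point, so the patch boxes instead.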
-use futures::future::{self, FutureResult}; +use futures::future::{self, Ready}; use libp2p_core::{InboundUpgrade, OutboundUpgrade, UpgradeInfo, upgrade::Negotiated}; use std::iter; use void::Void; @@ -38,20 +38,20 @@ impl UpgradeInfo for PlainTextConfig { impl InboundUpgrade for PlainTextConfig { type Output = Negotiated; type Error = Void; - type Future = FutureResult, Self::Error>; + type Future = Ready, Self::Error>>; fn upgrade_inbound(self, i: Negotiated, _: Self::Info) -> Self::Future { - future::ok(i) + future::ready(Ok(i)) } } impl OutboundUpgrade for PlainTextConfig { type Output = Negotiated; type Error = Void; - type Future = FutureResult, Self::Error>; + type Future = Ready, Self::Error>>; fn upgrade_outbound(self, i: Negotiated, _: Self::Info) -> Self::Future { - future::ok(i) + future::ready(Ok(i)) } } diff --git a/protocols/secio/Cargo.toml b/protocols/secio/Cargo.toml index c65d13bd..1c479dae 100644 --- a/protocols/secio/Cargo.toml +++ b/protocols/secio/Cargo.toml @@ -11,7 +11,8 @@ categories = ["network-programming", "asynchronous"] [dependencies] bytes = "0.4" -futures = "0.1" +futures-preview = "0.3.0-alpha.17" +futures_codec = "0.2.5" libp2p-core = { version = "0.12.0", path = "../../core" } log = "0.4.6" protobuf = "2.3" @@ -22,9 +23,9 @@ twofish = "0.2.0" ctr = "0.3" lazy_static = "1.2.0" rw-stream-sink = { version = "0.1.1", path = "../../misc/rw-stream-sink" } -tokio-io = "0.1.0" sha2 = "0.8.0" hmac = "0.7.0" +unsigned-varint = { git = "https://github.com/tomaka/unsigned-varint", branch = "futures-codec", features = ["codec"] } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] ring = { version = "0.14", features = ["use_heap"], default-features = false } @@ -43,11 +44,10 @@ secp256k1 = [] aes-all = ["aesni"] [dev-dependencies] +async-std = "0.99" criterion = "0.2" -libp2p-tcp = { version = "0.12.0", path = "../../transports/tcp" } libp2p-mplex = { version = "0.12.0", path = "../../muxers/mplex" } -tokio = "0.1" -tokio-tcp = "0.1" +libp2p-tcp = { version = "0.12.0", path = "../../transports/tcp" } [[bench]] name = "bench" diff --git a/protocols/secio/src/codec/decode.rs b/protocols/secio/src/codec/decode.rs index 4b0c73b3..7a80bec0 100644 --- a/protocols/secio/src/codec/decode.rs +++ b/protocols/secio/src/codec/decode.rs @@ -20,19 +20,14 @@ //! Individual messages decoding. -use bytes::BytesMut; use super::{Hmac, StreamCipher}; use crate::error::SecioError; -use futures::sink::Sink; -use futures::stream::Stream; -use futures::Async; -use futures::Poll; -use futures::StartSend; +use futures::prelude::*; use log::debug; -use std::cmp::min; +use std::{cmp::min, pin::Pin, task::Context, task::Poll}; -/// Wraps around a `Stream`. The buffers produced by the underlying stream +/// Wraps around a `Stream>`. The buffers produced by the underlying stream /// are decoded using the cipher and hmac. /// /// This struct implements `Stream`, whose stream item are frames of data without the length @@ -52,7 +47,6 @@ impl DecoderMiddleware { /// /// The `nonce` parameter denotes a sequence of bytes which are expected to be found at the /// beginning of the stream and are checked for equality. 
- #[inline] pub fn new(raw_stream: S, cipher: StreamCipher, hmac: Hmac, nonce: Vec) -> DecoderMiddleware { DecoderMiddleware { cipher_state: cipher, @@ -65,24 +59,22 @@ impl DecoderMiddleware { impl Stream for DecoderMiddleware where - S: Stream, + S: TryStream + Unpin, S::Error: Into, { - type Item = Vec; - type Error = SecioError; + type Item = Result, SecioError>; - #[inline] - fn poll(&mut self) -> Poll, Self::Error> { - let frame = match self.raw_stream.poll() { - Ok(Async::Ready(Some(t))) => t, - Ok(Async::Ready(None)) => return Ok(Async::Ready(None)), - Ok(Async::NotReady) => return Ok(Async::NotReady), - Err(err) => return Err(err.into()), + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let frame = match TryStream::try_poll_next(Pin::new(&mut self.raw_stream), cx) { + Poll::Ready(Some(Ok(t))) => t, + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => return Poll::Pending, + Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err.into()))), }; if frame.len() < self.hmac.num_bytes() { debug!("frame too short when decoding secio frame"); - return Err(SecioError::FrameTooShort); + return Poll::Ready(Some(Err(SecioError::FrameTooShort))); } let content_length = frame.len() - self.hmac.num_bytes(); { @@ -91,7 +83,7 @@ where if self.hmac.verify(crypted_data, expected_hash).is_err() { debug!("hmac mismatch when decoding secio frame"); - return Err(SecioError::HmacNotMatching); + return Poll::Ready(Some(Err(SecioError::HmacNotMatching))); } } @@ -103,35 +95,35 @@ where if !self.nonce.is_empty() { let n = min(data_buf.len(), self.nonce.len()); if data_buf[.. n] != self.nonce[.. n] { - return Err(SecioError::NonceVerificationFailed) + return Poll::Ready(Some(Err(SecioError::NonceVerificationFailed))) } self.nonce.drain(.. n); data_buf.drain(.. n); } - Ok(Async::Ready(Some(data_buf))) + Poll::Ready(Some(Ok(data_buf))) } } -impl Sink for DecoderMiddleware +impl Sink for DecoderMiddleware where - S: Sink, + S: Sink + Unpin, { - type SinkItem = S::SinkItem; - type SinkError = S::SinkError; + type Error = S::Error; - #[inline] - fn start_send(&mut self, item: Self::SinkItem) -> StartSend { - self.raw_stream.start_send(item) + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Sink::poll_ready(Pin::new(&mut self.raw_stream), cx) } - #[inline] - fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { - self.raw_stream.poll_complete() + fn start_send(mut self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { + Sink::start_send(Pin::new(&mut self.raw_stream), item) } - #[inline] - fn close(&mut self) -> Poll<(), Self::SinkError> { - self.raw_stream.close() + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Sink::poll_flush(Pin::new(&mut self.raw_stream), cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Sink::poll_close(Pin::new(&mut self.raw_stream), cx) } } diff --git a/protocols/secio/src/codec/encode.rs b/protocols/secio/src/codec/encode.rs index 36c3bcad..a0f0c04c 100644 --- a/protocols/secio/src/codec/encode.rs +++ b/protocols/secio/src/codec/encode.rs @@ -20,9 +20,9 @@ //! Individual messages encoding. -use bytes::BytesMut; use super::{Hmac, StreamCipher}; use futures::prelude::*; +use std::{pin::Pin, task::Context, task::Poll}; /// Wraps around a `Sink`. Encodes the buffers passed to it and passes it to the underlying sink. 
/// @@ -35,7 +35,6 @@ pub struct EncoderMiddleware { cipher_state: StreamCipher, hmac: Hmac, raw_sink: S, - pending: Option // buffer encrypted data which can not be sent right away } impl EncoderMiddleware { @@ -44,68 +43,44 @@ impl EncoderMiddleware { cipher_state: cipher, hmac, raw_sink: raw, - pending: None } } } -impl Sink for EncoderMiddleware +impl Sink> for EncoderMiddleware where - S: Sink, + S: Sink> + Unpin, { - type SinkItem = BytesMut; - type SinkError = S::SinkError; + type Error = S::Error; - fn start_send(&mut self, mut data_buf: Self::SinkItem) -> StartSend { - if let Some(data) = self.pending.take() { - if let AsyncSink::NotReady(data) = self.raw_sink.start_send(data)? { - self.pending = Some(data); - return Ok(AsyncSink::NotReady(data_buf)) - } - } - debug_assert!(self.pending.is_none()); + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Sink::poll_ready(Pin::new(&mut self.raw_sink), cx) + } + + fn start_send(mut self: Pin<&mut Self>, mut data_buf: Vec) -> Result<(), Self::Error> { // TODO if SinkError gets refactor to SecioError, then use try_apply_keystream self.cipher_state.encrypt(&mut data_buf[..]); let signature = self.hmac.sign(&data_buf[..]); data_buf.extend_from_slice(signature.as_ref()); - if let AsyncSink::NotReady(data) = self.raw_sink.start_send(data_buf)? { - self.pending = Some(data) - } - Ok(AsyncSink::Ready) + Sink::start_send(Pin::new(&mut self.raw_sink), data_buf) } - #[inline] - fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { - if let Some(data) = self.pending.take() { - if let AsyncSink::NotReady(data) = self.raw_sink.start_send(data)? { - self.pending = Some(data); - return Ok(Async::NotReady) - } - } - self.raw_sink.poll_complete() + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Sink::poll_flush(Pin::new(&mut self.raw_sink), cx) } - #[inline] - fn close(&mut self) -> Poll<(), Self::SinkError> { - if let Some(data) = self.pending.take() { - if let AsyncSink::NotReady(data) = self.raw_sink.start_send(data)? { - self.pending = Some(data); - return Ok(Async::NotReady) - } - } - self.raw_sink.close() + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Sink::poll_close(Pin::new(&mut self.raw_sink), cx) } } impl Stream for EncoderMiddleware where - S: Stream, + S: Stream + Unpin, { type Item = S::Item; - type Error = S::Error; - #[inline] - fn poll(&mut self) -> Poll, Self::Error> { - self.raw_sink.poll() + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Stream::poll_next(Pin::new(&mut self.raw_sink), cx) } } diff --git a/protocols/secio/src/codec/mod.rs b/protocols/secio/src/codec/mod.rs index 51a711cc..73c06e09 100644 --- a/protocols/secio/src/codec/mod.rs +++ b/protocols/secio/src/codec/mod.rs @@ -24,18 +24,18 @@ use self::decode::DecoderMiddleware; use self::encode::EncoderMiddleware; -use aes_ctr::stream_cipher; use crate::algo_support::Digest; +use futures::prelude::*; +use aes_ctr::stream_cipher; use hmac::{self, Mac}; use sha2::{Sha256, Sha512}; -use tokio_io::codec::length_delimited; -use tokio_io::{AsyncRead, AsyncWrite}; +use unsigned_varint::codec::UviBytes; mod decode; mod encode; /// Type returned by `full_codec`. -pub type FullCodec = DecoderMiddleware>>; +pub type FullCodec = DecoderMiddleware>>>>; pub type StreamCipher = Box; @@ -108,7 +108,7 @@ impl Hmac { /// The conversion between the stream/sink items and the socket is done with the given cipher and /// hash algorithm (which are generally decided during the handshake). 
pub fn full_codec( - socket: length_delimited::Framed, + socket: futures_codec::Framed>>, cipher_encoding: StreamCipher, encoding_hmac: Hmac, cipher_decoder: StreamCipher, @@ -116,7 +116,7 @@ pub fn full_codec( remote_nonce: Vec ) -> FullCodec where - S: AsyncRead + AsyncWrite, + S: AsyncRead + AsyncWrite + Unpin, { let encoder = EncoderMiddleware::new(socket, cipher_encoding, encoding_hmac); DecoderMiddleware::new(encoder, cipher_decoder, decoding_hmac, remote_nonce) @@ -124,56 +124,45 @@ where #[cfg(test)] mod tests { - use tokio::runtime::current_thread::Runtime; - use tokio_tcp::{TcpListener, TcpStream}; - use crate::stream_cipher::{ctr, Cipher}; - use super::full_codec; - use super::DecoderMiddleware; - use super::EncoderMiddleware; - use super::Hmac; + use super::{full_codec, DecoderMiddleware, EncoderMiddleware, Hmac}; use crate::algo_support::Digest; + use crate::stream_cipher::{ctr, Cipher}; use crate::error::SecioError; + use async_std::net::{TcpListener, TcpStream}; use bytes::BytesMut; - use futures::sync::mpsc::channel; - use futures::{Future, Sink, Stream, stream}; - use rand; - use std::io::Error as IoError; - use tokio_io::codec::length_delimited::Framed; + use futures::{prelude::*, channel::mpsc, channel::oneshot}; + use futures_codec::Framed; + use unsigned_varint::codec::UviBytes; - const NULL_IV : [u8; 16] = [0;16]; + const NULL_IV : [u8; 16] = [0; 16]; #[test] fn raw_encode_then_decode() { - let (data_tx, data_rx) = channel::(256); - let data_tx = data_tx.sink_map_err::<_, IoError>(|_| panic!()); - let data_rx = data_rx.map_err::(|_| panic!()); + let (data_tx, data_rx) = mpsc::channel::>(256); + let data_rx = data_rx.map(BytesMut::from); let cipher_key: [u8; 32] = rand::random(); let hmac_key: [u8; 32] = rand::random(); - - let encoder = EncoderMiddleware::new( + let mut encoder = EncoderMiddleware::new( data_tx, ctr(Cipher::Aes256, &cipher_key, &NULL_IV[..]), Hmac::from_key(Digest::Sha256, &hmac_key), ); - let decoder = DecoderMiddleware::new( - data_rx, + + let mut decoder = DecoderMiddleware::new( + data_rx.map(|v| Ok::<_, SecioError>(v)), ctr(Cipher::Aes256, &cipher_key, &NULL_IV[..]), Hmac::from_key(Digest::Sha256, &hmac_key), Vec::new() ); let data = b"hello world"; - - let data_sent = encoder.send(BytesMut::from(data.to_vec())).from_err(); - let data_received = decoder.into_future().map(|(n, _)| n).map_err(|(e, _)| e); - let mut rt = Runtime::new().unwrap(); - - let (_, decoded) = rt.block_on(data_sent.join(data_received)) - .map_err(|_| ()) - .unwrap(); - assert_eq!(&decoded.unwrap()[..], &data[..]); + futures::executor::block_on(async move { + encoder.send(data.to_vec()).await.unwrap(); + let rx = decoder.next().await.unwrap().unwrap(); + assert_eq!(rx, data); + }); } fn full_codec_encode_then_decode(cipher: Cipher) { @@ -185,53 +174,44 @@ mod tests { let data = b"hello world"; let data_clone = data.clone(); let nonce = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; - - let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); - let listener_addr = listener.local_addr().unwrap(); + let (l_a_tx, l_a_rx) = oneshot::channel(); let nonce2 = nonce.clone(); - let server = listener.incoming() - .into_future() - .map_err(|(e, _)| e) - .map(move |(connec, _)| { - full_codec( - Framed::new(connec.unwrap()), - ctr(cipher, &cipher_key[..key_size], &NULL_IV[..]), - Hmac::from_key(Digest::Sha256, &hmac_key), - ctr(cipher, &cipher_key[..key_size], &NULL_IV[..]), - Hmac::from_key(Digest::Sha256, &hmac_key), - nonce2 - ) - }, - ); + let server = async { + let listener = 
TcpListener::bind(&"127.0.0.1:0").await.unwrap(); + let listener_addr = listener.local_addr().unwrap(); + l_a_tx.send(listener_addr).unwrap(); - let client = TcpStream::connect(&listener_addr) - .map_err(|e| e.into()) - .map(move |stream| { - full_codec( - Framed::new(stream), - ctr(cipher, &cipher_key_clone[..key_size], &NULL_IV[..]), - Hmac::from_key(Digest::Sha256, &hmac_key_clone), - ctr(cipher, &cipher_key_clone[..key_size], &NULL_IV[..]), - Hmac::from_key(Digest::Sha256, &hmac_key_clone), - Vec::new() - ) - }); + let (connec, _) = listener.accept().await.unwrap(); + let codec = full_codec( + Framed::new(connec, UviBytes::default()), + ctr(cipher, &cipher_key[..key_size], &NULL_IV[..]), + Hmac::from_key(Digest::Sha256, &hmac_key), + ctr(cipher, &cipher_key[..key_size], &NULL_IV[..]), + Hmac::from_key(Digest::Sha256, &hmac_key), + nonce2.clone() + ); - let fin = server - .join(client) - .from_err::() - .and_then(|(server, client)| { - client - .send_all(stream::iter_ok::<_, IoError>(vec![nonce.into(), data_clone[..].into()])) - .map(move |_| server) - .from_err() - }) - .and_then(|server| server.concat2().from_err()); + let outcome = codec.map(|v| v.unwrap()).concat().await; + assert_eq!(outcome, data_clone); + }; - let mut rt = Runtime::new().unwrap(); - let received = rt.block_on(fin).unwrap(); - assert_eq!(received, data); + let client = async { + let listener_addr = l_a_rx.await.unwrap(); + let stream = TcpStream::connect(&listener_addr).await.unwrap(); + let mut codec = full_codec( + Framed::new(stream, UviBytes::default()), + ctr(cipher, &cipher_key_clone[..key_size], &NULL_IV[..]), + Hmac::from_key(Digest::Sha256, &hmac_key_clone), + ctr(cipher, &cipher_key_clone[..key_size], &NULL_IV[..]), + Hmac::from_key(Digest::Sha256, &hmac_key_clone), + Vec::new() + ); + codec.send(nonce.into()).await.unwrap(); + codec.send(data.to_vec().into()).await.unwrap(); + }; + + futures::executor::block_on(future::join(client, server)); } #[test] diff --git a/protocols/secio/src/exchange/impl_ring.rs b/protocols/secio/src/exchange/impl_ring.rs index 46a0943f..888dc963 100644 --- a/protocols/secio/src/exchange/impl_ring.rs +++ b/protocols/secio/src/exchange/impl_ring.rs @@ -43,7 +43,7 @@ pub type AgreementPrivateKey = ring_agreement::EphemeralPrivateKey; /// Generates a new key pair as part of the exchange. /// /// Returns the opaque private key and the corresponding public key. -pub fn generate_agreement(algorithm: KeyAgreement) -> impl Future), Error = SecioError> { +pub fn generate_agreement(algorithm: KeyAgreement) -> impl Future), SecioError>> { let rng = ring_rand::SystemRandom::new(); match ring_agreement::EphemeralPrivateKey::generate(algorithm.into(), &rng) { @@ -51,22 +51,22 @@ pub fn generate_agreement(algorithm: KeyAgreement) -> impl Future { debug!("failed to generate ECDH key"); - future::err(SecioError::EphemeralKeyGenerationFailed) + future::ready(Err(SecioError::EphemeralKeyGenerationFailed)) }, } } /// Finish the agreement. On success, returns the shared key that both remote agreed upon. 
pub fn agree(algorithm: KeyAgreement, my_private_key: AgreementPrivateKey, other_public_key: &[u8], _out_size: usize) - -> impl Future, Error = SecioError> + -> impl Future, SecioError>> { - ring_agreement::agree_ephemeral(my_private_key, algorithm.into(), - UntrustedInput::from(other_public_key), - SecioError::SecretGenerationFailed, - |key_material| Ok(key_material.to_vec())) - .into_future() + let ret = ring_agreement::agree_ephemeral(my_private_key, algorithm.into(), + UntrustedInput::from(other_public_key), + SecioError::SecretGenerationFailed, + |key_material| Ok(key_material.to_vec())); + future::ready(ret) } diff --git a/protocols/secio/src/exchange/mod.rs b/protocols/secio/src/exchange/mod.rs index bb59b4e6..5fdecbb8 100644 --- a/protocols/secio/src/exchange/mod.rs +++ b/protocols/secio/src/exchange/mod.rs @@ -44,14 +44,14 @@ pub struct AgreementPrivateKey(platform::AgreementPrivateKey); /// /// Returns the opaque private key and the corresponding public key. #[inline] -pub fn generate_agreement(algorithm: KeyAgreement) -> impl Future), Error = SecioError> { - platform::generate_agreement(algorithm).map(|(pr, pu)| (AgreementPrivateKey(pr), pu)) +pub fn generate_agreement(algorithm: KeyAgreement) -> impl Future), SecioError>> { + platform::generate_agreement(algorithm).map_ok(|(pr, pu)| (AgreementPrivateKey(pr), pu)) } /// Finish the agreement. On success, returns the shared key that both remote agreed upon. #[inline] pub fn agree(algorithm: KeyAgreement, my_private_key: AgreementPrivateKey, other_public_key: &[u8], out_size: usize) - -> impl Future, Error = SecioError> + -> impl Future, SecioError>> { platform::agree(algorithm, my_private_key.0, other_public_key, out_size) } diff --git a/protocols/secio/src/handshake.rs b/protocols/secio/src/handshake.rs index 6e0e989f..b90ea93a 100644 --- a/protocols/secio/src/handshake.rs +++ b/protocols/secio/src/handshake.rs @@ -19,15 +19,11 @@ // DEALINGS IN THE SOFTWARE. use crate::algo_support; -use bytes::BytesMut; use crate::codec::{full_codec, FullCodec, Hmac}; -use crate::stream_cipher::{Cipher, ctr}; +use crate::stream_cipher::ctr; use crate::error::SecioError; use crate::exchange; -use futures::future; -use futures::sink::Sink; -use futures::stream::Stream; -use futures::Future; +use futures::prelude::*; use libp2p_core::PublicKey; use log::{debug, trace}; use protobuf::parse_from_bytes as protobuf_parse_from_bytes; @@ -37,447 +33,291 @@ use sha2::{Digest as ShaDigestTrait, Sha256}; use std::cmp::{self, Ordering}; use std::io::{Error as IoError, ErrorKind as IoErrorKind}; use crate::structs_proto::{Exchange, Propose}; -use tokio_io::codec::length_delimited; -use tokio_io::{AsyncRead, AsyncWrite}; -use crate::{KeyAgreement, SecioConfig}; - -// This struct contains the whole context of a handshake, and is filled progressively -// throughout the various parts of the handshake. -struct HandshakeContext { - config: SecioConfig, - state: T -} - -// HandshakeContext<()> --with_local-> HandshakeContext -struct Local { - // Locally-generated random number. The array size can be changed without any repercussion. - nonce: [u8; 16], - // Our encoded local public key - public_key_encoded: Vec, - // Our local proposition's raw bytes: - proposition_bytes: Vec -} - -// HandshakeContext --with_remote-> HandshakeContext -struct Remote { - local: Local, - // The remote's proposition's raw bytes: - proposition_bytes: BytesMut, - // The remote's public key: - public_key: PublicKey, - // The remote's `nonce`. 
- // If the NONCE size is actually part of the protocol, we can change this to a fixed-size - // array instead of a `Vec`. - nonce: Vec, - // Set to `ordering( - // hash(concat(remote-pubkey, local-none)), - // hash(concat(local-pubkey, remote-none)) - // )`. - // `Ordering::Equal` is an invalid value (as it would mean we're talking to ourselves). - // - // Since everything is symmetrical, this value is used to determine what should be ours - // and what should be the remote's. - hashes_ordering: Ordering, - // Crypto algorithms chosen for the communication: - chosen_exchange: KeyAgreement, - chosen_cipher: Cipher, - chosen_hash: algo_support::Digest, -} - -// HandshakeContext --with_ephemeral-> HandshakeContext -struct Ephemeral { - remote: Remote, - // Ephemeral keypair generated for the handshake: - local_tmp_priv_key: exchange::AgreementPrivateKey, - local_tmp_pub_key: Vec -} - -// HandshakeContext --take_private_key-> HandshakeContext -struct PubEphemeral { - remote: Remote, - local_tmp_pub_key: Vec -} - -impl HandshakeContext<()> { - fn new(config: SecioConfig) -> Self { - HandshakeContext { - config, - state: () - } - } - - // Setup local proposition. - fn with_local(self) -> Result, SecioError> { - let mut nonce = [0; 16]; - rand::thread_rng() - .try_fill_bytes(&mut nonce) - .map_err(|_| SecioError::NonceGenerationFailed)?; - - let public_key_encoded = self.config.key.public().into_protobuf_encoding(); - - // Send our proposition with our nonce, public key and supported protocols. - let mut proposition = Propose::new(); - proposition.set_rand(nonce.to_vec()); - proposition.set_pubkey(public_key_encoded.clone()); - - if let Some(ref p) = self.config.agreements_prop { - trace!("agreements proposition: {}", p); - proposition.set_exchanges(p.clone()) - } else { - trace!("agreements proposition: {}", algo_support::DEFAULT_AGREEMENTS_PROPOSITION); - proposition.set_exchanges(algo_support::DEFAULT_AGREEMENTS_PROPOSITION.into()) - } - - if let Some(ref p) = self.config.ciphers_prop { - trace!("ciphers proposition: {}", p); - proposition.set_ciphers(p.clone()) - } else { - trace!("ciphers proposition: {}", algo_support::DEFAULT_CIPHERS_PROPOSITION); - proposition.set_ciphers(algo_support::DEFAULT_CIPHERS_PROPOSITION.into()) - } - - if let Some(ref p) = self.config.digests_prop { - trace!("digests proposition: {}", p); - proposition.set_hashes(p.clone()) - } else { - trace!("digests proposition: {}", algo_support::DEFAULT_DIGESTS_PROPOSITION); - proposition.set_hashes(algo_support::DEFAULT_DIGESTS_PROPOSITION.into()) - } - - let proposition_bytes = proposition.write_to_bytes()?; - - Ok(HandshakeContext { - config: self.config, - state: Local { - nonce, - public_key_encoded, - proposition_bytes - } - }) - } -} - -impl HandshakeContext { - // Process remote proposition. - fn with_remote(self, b: BytesMut) -> Result, SecioError> { - let mut prop = match protobuf_parse_from_bytes::(&b) { - Ok(prop) => prop, - Err(_) => { - debug!("failed to parse remote's proposition protobuf message"); - return Err(SecioError::HandshakeParsingFailure); - } - }; - - let public_key_encoded = prop.take_pubkey(); - let nonce = prop.take_rand(); - - let pubkey = match PublicKey::from_protobuf_encoding(&public_key_encoded) { - Ok(p) => p, - Err(_) => { - debug!("failed to parse remote's proposition's pubkey protobuf"); - return Err(SecioError::HandshakeParsingFailure); - }, - }; - - // In order to determine which protocols to use, we compute two hashes and choose - // based on which hash is larger. 
- let hashes_ordering = { - let oh1 = { - let mut ctx = Sha256::new(); - ctx.input(&public_key_encoded); - ctx.input(&self.state.nonce); - ctx.result() - }; - - let oh2 = { - let mut ctx = Sha256::new(); - ctx.input(&self.state.public_key_encoded); - ctx.input(&nonce); - ctx.result() - }; - - oh1.as_ref().cmp(&oh2.as_ref()) - }; - - let chosen_exchange = { - let ours = self.config.agreements_prop.as_ref() - .map(|s| s.as_ref()) - .unwrap_or(algo_support::DEFAULT_AGREEMENTS_PROPOSITION); - let theirs = &prop.get_exchanges(); - match algo_support::select_agreement(hashes_ordering, ours, theirs) { - Ok(a) => a, - Err(err) => { - debug!("failed to select an exchange protocol"); - return Err(err); - } - } - }; - - let chosen_cipher = { - let ours = self.config.ciphers_prop.as_ref() - .map(|s| s.as_ref()) - .unwrap_or(algo_support::DEFAULT_CIPHERS_PROPOSITION); - let theirs = &prop.get_ciphers(); - match algo_support::select_cipher(hashes_ordering, ours, theirs) { - Ok(a) => { - debug!("selected cipher: {:?}", a); - a - } - Err(err) => { - debug!("failed to select a cipher protocol"); - return Err(err); - } - } - }; - - let chosen_hash = { - let ours = self.config.digests_prop.as_ref() - .map(|s| s.as_ref()) - .unwrap_or(algo_support::DEFAULT_DIGESTS_PROPOSITION); - let theirs = &prop.get_hashes(); - match algo_support::select_digest(hashes_ordering, ours, theirs) { - Ok(a) => { - debug!("selected hash: {:?}", a); - a - } - Err(err) => { - debug!("failed to select a hash protocol"); - return Err(err); - } - } - }; - - Ok(HandshakeContext { - config: self.config, - state: Remote { - local: self.state, - proposition_bytes: b, - public_key: pubkey, - nonce, - hashes_ordering, - chosen_exchange, - chosen_cipher, - chosen_hash - } - }) - } -} - -impl HandshakeContext { - fn with_ephemeral(self, sk: exchange::AgreementPrivateKey, pk: Vec) -> HandshakeContext { - HandshakeContext { - config: self.config, - state: Ephemeral { - remote: self.state, - local_tmp_priv_key: sk, - local_tmp_pub_key: pk - } - } - } -} - -impl HandshakeContext { - fn take_private_key(self) -> (HandshakeContext, exchange::AgreementPrivateKey) { - let context = HandshakeContext { - config: self.config, - state: PubEphemeral { - remote: self.state.remote, - local_tmp_pub_key: self.state.local_tmp_pub_key - } - }; - (context, self.state.local_tmp_priv_key) - } -} +use crate::SecioConfig; /// Performs a handshake on the given socket. /// /// This function expects that the remote is identified with `remote_public_key`, and the remote -/// will expect that we are identified with `local_key`.Any mismatch somewhere will produce a +/// will expect that we are identified with `local_key`. Any mismatch somewhere will produce a /// `SecioError`. /// /// On success, returns an object that implements the `Sink` and `Stream` trait whose items are /// buffers of data, plus the public key of the remote, plus the ephemeral public key used during /// negotiation. -pub fn handshake<'a, S: 'a>(socket: S, config: SecioConfig) - -> impl Future, PublicKey, Vec), Error = SecioError> +pub async fn handshake<'a, S: 'a>(socket: S, config: SecioConfig) + -> Result<(FullCodec, PublicKey, Vec), SecioError> where - S: AsyncRead + AsyncWrite + Send, + S: AsyncRead + AsyncWrite + Send + Unpin, { - // The handshake messages all start with a 4-bytes message length prefix. 
- let socket = length_delimited::Builder::new() - .big_endian() - .length_field_length(4) - .new_framed(socket); + // The handshake messages all start with a variable-length integer indicating the size. + let mut socket = futures_codec::Framed::new( + socket, + unsigned_varint::codec::UviBytes::>::default() + ); - future::ok::<_, SecioError>(HandshakeContext::new(config)) - .and_then(|context| { - // Generate our nonce. - let context = context.with_local()?; - trace!("starting handshake; local nonce = {:?}", context.state.nonce); - Ok(context) - }) - .and_then(|context| { - trace!("sending proposition to remote"); - socket.send(BytesMut::from(context.state.proposition_bytes.clone())) - .from_err() - .map(|s| (s, context)) - }) - // Receive the remote's proposition. - .and_then(move |(socket, context)| { - socket.into_future() - .map_err(|(e, _)| e.into()) - .and_then(move |(prop_raw, socket)| { - let context = match prop_raw { - Some(p) => context.with_remote(p)?, - None => { - let err = IoError::new(IoErrorKind::BrokenPipe, "unexpected eof"); - debug!("unexpected eof while waiting for remote's proposition"); - return Err(err.into()) - }, - }; - trace!("received proposition from remote; pubkey = {:?}; nonce = {:?}", - context.state.public_key, context.state.nonce); - Ok((socket, context)) - }) - }) - // Generate an ephemeral key for the negotiation. - .and_then(|(socket, context)| { - exchange::generate_agreement(context.state.chosen_exchange) - .map(move |(tmp_priv_key, tmp_pub_key)| (socket, context, tmp_priv_key, tmp_pub_key)) - }) - // Send the ephemeral pub key to the remote in an `Exchange` struct. The `Exchange` also - // contains a signature of the two propositions encoded with our static public key. - .and_then(|(socket, context, tmp_priv, tmp_pub_key)| { - let context = context.with_ephemeral(tmp_priv, tmp_pub_key.clone()); - let exchange = { - let mut data_to_sign = context.state.remote.local.proposition_bytes.clone(); - data_to_sign.extend_from_slice(&context.state.remote.proposition_bytes); - data_to_sign.extend_from_slice(&tmp_pub_key); + let local_nonce = { + let mut local_nonce = [0; 16]; + rand::thread_rng() + .try_fill_bytes(&mut local_nonce) + .map_err(|_| SecioError::NonceGenerationFailed)?; + local_nonce + }; - let mut exchange = Exchange::new(); - exchange.set_epubkey(tmp_pub_key); - match context.config.key.sign(&data_to_sign) { - Ok(sig) => exchange.set_signature(sig), - Err(_) => return Err(SecioError::SigningFailure) - } - exchange - }; - let local_exch = exchange.write_to_bytes()?; - Ok((BytesMut::from(local_exch), socket, context)) - }) - // Send our local `Exchange`. - .and_then(|(local_exch, socket, context)| { - trace!("sending exchange to remote"); - socket.send(local_exch) - .from_err() - .map(|s| (s, context)) - }) - // Receive the remote's `Exchange`. 
- .and_then(move |(socket, context)| { - socket.into_future() - .map_err(|(e, _)| e.into()) - .and_then(move |(raw, socket)| { - let raw = match raw { - Some(r) => r, - None => { - let err = IoError::new(IoErrorKind::BrokenPipe, "unexpected eof"); - debug!("unexpected eof while waiting for remote's exchange"); - return Err(err.into()) - }, - }; + let local_public_key_encoded = config.key.public().into_protobuf_encoding(); - let remote_exch = match protobuf_parse_from_bytes::(&raw) { - Ok(e) => e, - Err(err) => { - debug!("failed to parse remote's exchange protobuf; {:?}", err); - return Err(SecioError::HandshakeParsingFailure); - } - }; + // Send our proposition with our nonce, public key and supported protocols. + let mut local_proposition = Propose::new(); + local_proposition.set_rand(local_nonce.to_vec()); + local_proposition.set_pubkey(local_public_key_encoded.clone()); - trace!("received and decoded the remote's exchange"); - Ok((remote_exch, socket, context)) - }) - }) - // Check the validity of the remote's `Exchange`. This verifies that the remote was really - // the sender of its proposition, and that it is the owner of both its global and ephemeral - // keys. - .and_then(|(remote_exch, socket, context)| { - let mut data_to_verify = context.state.remote.proposition_bytes.clone(); - data_to_verify.extend_from_slice(&context.state.remote.local.proposition_bytes); - data_to_verify.extend_from_slice(remote_exch.get_epubkey()); + if let Some(ref p) = config.agreements_prop { + trace!("agreements proposition: {}", p); + local_proposition.set_exchanges(p.clone()) + } else { + trace!("agreements proposition: {}", algo_support::DEFAULT_AGREEMENTS_PROPOSITION); + local_proposition.set_exchanges(algo_support::DEFAULT_AGREEMENTS_PROPOSITION.into()) + } - if !context.state.remote.public_key.verify(&data_to_verify, remote_exch.get_signature()) { - return Err(SecioError::SignatureVerificationFailed) + if let Some(ref p) = config.ciphers_prop { + trace!("ciphers proposition: {}", p); + local_proposition.set_ciphers(p.clone()) + } else { + trace!("ciphers proposition: {}", algo_support::DEFAULT_CIPHERS_PROPOSITION); + local_proposition.set_ciphers(algo_support::DEFAULT_CIPHERS_PROPOSITION.into()) + } + + if let Some(ref p) = config.digests_prop { + trace!("digests proposition: {}", p); + local_proposition.set_hashes(p.clone()) + } else { + trace!("digests proposition: {}", algo_support::DEFAULT_DIGESTS_PROPOSITION); + local_proposition.set_hashes(algo_support::DEFAULT_DIGESTS_PROPOSITION.into()) + } + + let local_proposition_bytes = local_proposition.write_to_bytes()?; + trace!("starting handshake; local nonce = {:?}", local_nonce); + + trace!("sending proposition to remote"); + socket.send(local_proposition_bytes.clone()).await?; + + // Receive the remote's proposition. 
+ let remote_proposition_bytes = match socket.next().await { + Some(b) => b?, + None => { + let err = IoError::new(IoErrorKind::BrokenPipe, "unexpected eof"); + debug!("unexpected eof while waiting for remote's proposition"); + return Err(err.into()) + }, + }; + + let mut remote_proposition = match protobuf_parse_from_bytes::(&remote_proposition_bytes) { + Ok(prop) => prop, + Err(_) => { + debug!("failed to parse remote's proposition protobuf message"); + return Err(SecioError::HandshakeParsingFailure); + } + }; + + let remote_public_key_encoded = remote_proposition.take_pubkey(); + let remote_nonce = remote_proposition.take_rand(); + + let remote_public_key = match PublicKey::from_protobuf_encoding(&remote_public_key_encoded) { + Ok(p) => p, + Err(_) => { + debug!("failed to parse remote's proposition's pubkey protobuf"); + return Err(SecioError::HandshakeParsingFailure); + }, + }; + trace!("received proposition from remote; pubkey = {:?}; nonce = {:?}", + remote_public_key, remote_nonce); + + // In order to determine which protocols to use, we compute two hashes and choose + // based on which hash is larger. + let hashes_ordering = { + let oh1 = { + let mut ctx = Sha256::new(); + ctx.input(&remote_public_key_encoded); + ctx.input(&local_nonce); + ctx.result() + }; + + let oh2 = { + let mut ctx = Sha256::new(); + ctx.input(&local_public_key_encoded); + ctx.input(&remote_nonce); + ctx.result() + }; + + oh1.as_ref().cmp(&oh2.as_ref()) + }; + + let chosen_exchange = { + let ours = config.agreements_prop.as_ref() + .map(|s| s.as_ref()) + .unwrap_or(algo_support::DEFAULT_AGREEMENTS_PROPOSITION); + let theirs = &remote_proposition.get_exchanges(); + match algo_support::select_agreement(hashes_ordering, ours, theirs) { + Ok(a) => a, + Err(err) => { + debug!("failed to select an exchange protocol"); + return Err(err); } + } + }; - trace!("successfully verified the remote's signature"); - Ok((remote_exch, socket, context)) - }) - // Generate a key from the local ephemeral private key and the remote ephemeral public key, - // derive from it a cipher key, an iv, and a hmac key, and build the encoder/decoder. - .and_then(|(remote_exch, socket, context)| { - let (context, local_priv_key) = context.take_private_key(); - let key_size = context.state.remote.chosen_hash.num_bytes(); - exchange::agree(context.state.remote.chosen_exchange, local_priv_key, remote_exch.get_epubkey(), key_size) - .map(move |key_material| (socket, context, key_material)) - }) - // Generate a key from the local ephemeral private key and the remote ephemeral public key, - // derive from it a cipher key, an iv, and a hmac key, and build the encoder/decoder. 
- .and_then(|(socket, context, key_material)| { - let chosen_cipher = context.state.remote.chosen_cipher; - let cipher_key_size = chosen_cipher.key_size(); - let iv_size = chosen_cipher.iv_size(); + let chosen_cipher = { + let ours = config.ciphers_prop.as_ref() + .map(|s| s.as_ref()) + .unwrap_or(algo_support::DEFAULT_CIPHERS_PROPOSITION); + let theirs = &remote_proposition.get_ciphers(); + match algo_support::select_cipher(hashes_ordering, ours, theirs) { + Ok(a) => { + debug!("selected cipher: {:?}", a); + a + } + Err(err) => { + debug!("failed to select a cipher protocol"); + return Err(err); + } + } + }; - let key = Hmac::from_key(context.state.remote.chosen_hash, &key_material); - let mut longer_key = vec![0u8; 2 * (iv_size + cipher_key_size + 20)]; - stretch_key(key, &mut longer_key); + let chosen_hash = { + let ours = config.digests_prop.as_ref() + .map(|s| s.as_ref()) + .unwrap_or(algo_support::DEFAULT_DIGESTS_PROPOSITION); + let theirs = &remote_proposition.get_hashes(); + match algo_support::select_digest(hashes_ordering, ours, theirs) { + Ok(a) => { + debug!("selected hash: {:?}", a); + a + } + Err(err) => { + debug!("failed to select a hash protocol"); + return Err(err); + } + } + }; - let (local_infos, remote_infos) = { - let (first_half, second_half) = longer_key.split_at(longer_key.len() / 2); - match context.state.remote.hashes_ordering { - Ordering::Equal => { - let msg = "equal digest of public key and nonce for local and remote"; - return Err(SecioError::InvalidProposition(msg)) - } - Ordering::Less => (second_half, first_half), - Ordering::Greater => (first_half, second_half), + // Generate an ephemeral key for the negotiation. + let (tmp_priv_key, tmp_pub_key) = exchange::generate_agreement(chosen_exchange).await?; + + // Send the ephemeral pub key to the remote in an `Exchange` struct. The `Exchange` also + // contains a signature of the two propositions encoded with our static public key. + let local_exchange = { + let mut data_to_sign = local_proposition_bytes.clone(); + data_to_sign.extend_from_slice(&remote_proposition_bytes); + data_to_sign.extend_from_slice(&tmp_pub_key); + + let mut exchange = Exchange::new(); + exchange.set_epubkey(tmp_pub_key.clone()); + match config.key.sign(&data_to_sign) { + Ok(sig) => exchange.set_signature(sig), + Err(_) => return Err(SecioError::SigningFailure) + } + exchange + }; + let local_exch = local_exchange.write_to_bytes()?; + + // Send our local `Exchange`. + trace!("sending exchange to remote"); + socket.send(local_exch).await?; + + // Receive the remote's `Exchange`. + let remote_exch = { + let raw = match socket.next().await { + Some(r) => r?, + None => { + let err = IoError::new(IoErrorKind::BrokenPipe, "unexpected eof"); + debug!("unexpected eof while waiting for remote's exchange"); + return Err(err.into()) + }, + }; + + match protobuf_parse_from_bytes::(&raw) { + Ok(e) => { + trace!("received and decoded the remote's exchange"); + e + }, + Err(err) => { + debug!("failed to parse remote's exchange protobuf; {:?}", err); + return Err(SecioError::HandshakeParsingFailure); + } + } + }; + + // Check the validity of the remote's `Exchange`. This verifies that the remote was really + // the sender of its proposition, and that it is the owner of both its global and ephemeral + // keys. 
+ { + let mut data_to_verify = remote_proposition_bytes.clone(); + data_to_verify.extend_from_slice(&local_proposition_bytes); + data_to_verify.extend_from_slice(remote_exch.get_epubkey()); + + if !remote_public_key.verify(&data_to_verify, remote_exch.get_signature()) { + return Err(SecioError::SignatureVerificationFailed) + } + + trace!("successfully verified the remote's signature"); + } + + // Generate a key from the local ephemeral private key and the remote ephemeral public key, + // derive from it a cipher key, an iv, and a hmac key, and build the encoder/decoder. + let key_material = exchange::agree(chosen_exchange, tmp_priv_key, remote_exch.get_epubkey(), chosen_hash.num_bytes()).await?; + + // Generate a key from the local ephemeral private key and the remote ephemeral public key, + // derive from it a cipher key, an iv, and a hmac key, and build the encoder/decoder. + let mut codec = { + let cipher_key_size = chosen_cipher.key_size(); + let iv_size = chosen_cipher.iv_size(); + + let key = Hmac::from_key(chosen_hash, &key_material); + let mut longer_key = vec![0u8; 2 * (iv_size + cipher_key_size + 20)]; + stretch_key(key, &mut longer_key); + + let (local_infos, remote_infos) = { + let (first_half, second_half) = longer_key.split_at(longer_key.len() / 2); + match hashes_ordering { + Ordering::Equal => { + let msg = "equal digest of public key and nonce for local and remote"; + return Err(SecioError::InvalidProposition(msg)) } - }; + Ordering::Less => (second_half, first_half), + Ordering::Greater => (first_half, second_half), + } + }; - let (encoding_cipher, encoding_hmac) = { - let (iv, rest) = local_infos.split_at(iv_size); - let (cipher_key, mac_key) = rest.split_at(cipher_key_size); - let hmac = Hmac::from_key(context.state.remote.chosen_hash, mac_key); - let cipher = ctr(chosen_cipher, cipher_key, iv); - (cipher, hmac) - }; + let (encoding_cipher, encoding_hmac) = { + let (iv, rest) = local_infos.split_at(iv_size); + let (cipher_key, mac_key) = rest.split_at(cipher_key_size); + let hmac = Hmac::from_key(chosen_hash, mac_key); + let cipher = ctr(chosen_cipher, cipher_key, iv); + (cipher, hmac) + }; - let (decoding_cipher, decoding_hmac) = { - let (iv, rest) = remote_infos.split_at(iv_size); - let (cipher_key, mac_key) = rest.split_at(cipher_key_size); - let hmac = Hmac::from_key(context.state.remote.chosen_hash, mac_key); - let cipher = ctr(chosen_cipher, cipher_key, iv); - (cipher, hmac) - }; + let (decoding_cipher, decoding_hmac) = { + let (iv, rest) = remote_infos.split_at(iv_size); + let (cipher_key, mac_key) = rest.split_at(cipher_key_size); + let hmac = Hmac::from_key(chosen_hash, mac_key); + let cipher = ctr(chosen_cipher, cipher_key, iv); + (cipher, hmac) + }; - let codec = full_codec( - socket, - encoding_cipher, - encoding_hmac, - decoding_cipher, - decoding_hmac, - context.state.remote.local.nonce.to_vec() - ); - Ok((codec, context)) - }) - // We send back their nonce to check if the connection works. - .and_then(|(codec, context)| { - let remote_nonce = context.state.remote.nonce.clone(); - trace!("checking encryption by sending back remote's nonce"); - codec.send(BytesMut::from(remote_nonce)) - .map(|s| (s, context.state.remote.public_key, context.state.local_tmp_pub_key)) - .from_err() - }) + full_codec( + socket, + encoding_cipher, + encoding_hmac, + decoding_cipher, + decoding_hmac, + local_nonce.to_vec() + ) + }; + + // We send back their nonce to check if the connection works. 
+ trace!("checking encryption by sending back remote's nonce"); + codec.send(remote_nonce).await?; + + Ok((codec, remote_public_key, tmp_pub_key)) } /// Custom algorithm translated from reference implementations. Needs to be the same algorithm @@ -522,16 +362,10 @@ where D: ::hmac::digest::Input + ::hmac::digest::BlockInput + #[cfg(test)] mod tests { - use bytes::BytesMut; + use super::{handshake, stretch_key}; + use crate::{algo_support::Digest, codec::Hmac, SecioConfig}; use libp2p_core::identity; - use tokio::runtime::current_thread::Runtime; - use tokio_tcp::{TcpListener, TcpStream}; - use crate::{SecioConfig, SecioError}; - use super::handshake; - use super::stretch_key; - use crate::algo_support::Digest; - use crate::codec::Hmac; - use futures::prelude::*; + use futures::{prelude::*, channel::oneshot}; #[test] #[cfg(not(any(target_os = "emscripten", target_os = "unknown")))] @@ -573,38 +407,30 @@ mod tests { } fn handshake_with_self_succeeds(key1: SecioConfig, key2: SecioConfig) { - let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); - let listener_addr = listener.local_addr().unwrap(); + let (l_a_tx, l_a_rx) = oneshot::channel(); - let server = listener - .incoming() - .into_future() - .map_err(|(e, _)| e.into()) - .and_then(move |(connec, _)| handshake(connec.unwrap(), key1)) - .and_then(|(connec, _, _)| { - let (sink, stream) = connec.split(); - stream - .filter(|v| !v.is_empty()) - .forward(sink.with(|v| Ok::<_, SecioError>(BytesMut::from(v)))) - }); + async_std::task::spawn(async move { + let listener = async_std::net::TcpListener::bind(&"127.0.0.1:0").await.unwrap(); + l_a_tx.send(listener.local_addr().unwrap()).unwrap(); + let connec = listener.accept().await.unwrap().0; + let mut codec = handshake(connec, key1).await.unwrap().0; + while let Some(packet) = codec.next().await { + let packet = packet.unwrap(); + if !packet.is_empty() { + codec.send(packet.into()).await.unwrap(); + } + } + }); - let client = TcpStream::connect(&listener_addr) - .map_err(|e| e.into()) - .and_then(move |stream| handshake(stream, key2)) - .and_then(|(connec, _, _)| { - connec.send("hello".into()) - .from_err() - .and_then(|connec| { - connec.filter(|v| !v.is_empty()) - .into_future() - .map(|(v, _)| v) - .map_err(|(e, _)| e) - }) - .map(|v| assert_eq!(b"hello", &v.unwrap()[..])) - }); - - let mut rt = Runtime::new().unwrap(); - let _ = rt.block_on(server.join(client)).unwrap(); + futures::executor::block_on(async move { + let listen_addr = l_a_rx.await.unwrap(); + let connec = async_std::net::TcpStream::connect(&listen_addr).await.unwrap(); + let mut codec = handshake(connec, key2).await.unwrap().0; + codec.send(b"hello".to_vec().into()).await.unwrap(); + let mut packets_stream = codec.filter(|p| future::ready(!p.as_ref().unwrap().is_empty())); + let packet = packets_stream.next().await.unwrap(); + assert_eq!(packet.unwrap(), b"hello"); + }); } #[test] diff --git a/protocols/secio/src/lib.rs b/protocols/secio/src/lib.rs index 2965a921..60e55e66 100644 --- a/protocols/secio/src/lib.rs +++ b/protocols/secio/src/lib.rs @@ -29,7 +29,7 @@ //! //! ```no_run //! # fn main() { -//! use futures::Future; +//! use futures::prelude::*; //! use libp2p_secio::{SecioConfig, SecioOutput}; //! use libp2p_core::{PeerId, Multiaddr, identity}; //! 
use libp2p_core::transport::Transport; @@ -57,20 +57,12 @@ pub use self::error::SecioError; -use bytes::BytesMut; use futures::stream::MapErr as StreamMapErr; -use futures::{Future, Poll, Sink, StartSend, Stream}; -use libp2p_core::{ - PeerId, - PublicKey, - identity, - upgrade::{UpgradeInfo, InboundUpgrade, OutboundUpgrade, Negotiated} -}; +use futures::{prelude::*, io::Initializer}; +use libp2p_core::{PeerId, PublicKey, identity, upgrade::{UpgradeInfo, InboundUpgrade, OutboundUpgrade, Negotiated}}; use log::debug; use rw_stream_sink::RwStreamSink; -use std::io; -use std::iter; -use tokio_io::{AsyncRead, AsyncWrite}; +use std::{io, iter, pin::Pin, task::Context, task::Poll}; mod algo_support; mod codec; @@ -134,13 +126,13 @@ impl SecioConfig { self } - fn handshake(self, socket: T) -> impl Future), Error=SecioError> + fn handshake(self, socket: T) -> impl Future), SecioError>> where - T: AsyncRead + AsyncWrite + Send + 'static + T: AsyncRead + AsyncWrite + Unpin + Send + 'static { debug!("Starting secio upgrade"); SecioMiddleware::handshake(socket, self) - .map(|(stream_sink, pubkey, ephemeral)| { + .map_ok(|(stream_sink, pubkey, ephemeral)| { let mapped = stream_sink.map_err(map_err as fn(_) -> _); let peer = pubkey.clone().into_peer_id(); let io = SecioOutput { @@ -177,55 +169,59 @@ impl UpgradeInfo for SecioConfig { impl InboundUpgrade for SecioConfig where - T: AsyncRead + AsyncWrite + Send + 'static + T: AsyncRead + AsyncWrite + Unpin + Send + 'static { type Output = (PeerId, SecioOutput>); type Error = SecioError; - type Future = Box + Send>; + type Future = Pin> + Send>>; fn upgrade_inbound(self, socket: Negotiated, _: Self::Info) -> Self::Future { - Box::new(self.handshake(socket)) + Box::pin(self.handshake(socket)) } } impl OutboundUpgrade for SecioConfig where - T: AsyncRead + AsyncWrite + Send + 'static + T: AsyncRead + AsyncWrite + Unpin + Send + 'static { type Output = (PeerId, SecioOutput>); type Error = SecioError; - type Future = Box + Send>; + type Future = Pin> + Send>>; fn upgrade_outbound(self, socket: Negotiated, _: Self::Info) -> Self::Future { - Box::new(self.handshake(socket)) + Box::pin(self.handshake(socket)) } } -impl io::Read for SecioOutput { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.stream.read(buf) +impl AsyncRead for SecioOutput { + fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) + -> Poll> + { + AsyncRead::poll_read(Pin::new(&mut self.stream), cx, buf) + } + + unsafe fn initializer(&self) -> Initializer { + self.stream.initializer() } } -impl AsyncRead for SecioOutput { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { - self.stream.prepare_uninitialized_buffer(buf) - } -} - -impl io::Write for SecioOutput { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.stream.write(buf) +impl AsyncWrite for SecioOutput { + fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) + -> Poll> + { + AsyncWrite::poll_write(Pin::new(&mut self.stream), cx, buf) } - fn flush(&mut self) -> io::Result<()> { - self.stream.flush() + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) + -> Poll> + { + AsyncWrite::poll_flush(Pin::new(&mut self.stream), cx) } -} -impl AsyncWrite for SecioOutput { - fn shutdown(&mut self) -> Poll<(), io::Error> { - self.stream.shutdown() + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) + -> Poll> + { + AsyncWrite::poll_close(Pin::new(&mut self.stream), cx) } } @@ -244,54 +240,52 @@ pub struct SecioMiddleware { impl SecioMiddleware where - S: 
AsyncRead + AsyncWrite + Send, + S: AsyncRead + AsyncWrite + Send + Unpin + 'static, { /// Attempts to perform a handshake on the given socket. /// /// On success, produces a `SecioMiddleware` that can then be used to encode/decode /// communications, plus the public key of the remote, plus the ephemeral public key. pub fn handshake(socket: S, config: SecioConfig) - -> impl Future, PublicKey, Vec), Error = SecioError> + -> impl Future, PublicKey, Vec), SecioError>> { - handshake::handshake(socket, config).map(|(inner, pubkey, ephemeral)| { + handshake::handshake(socket, config).map_ok(|(inner, pubkey, ephemeral)| { let inner = SecioMiddleware { inner }; (inner, pubkey, ephemeral) }) } } -impl Sink for SecioMiddleware +impl Sink> for SecioMiddleware where - S: AsyncRead + AsyncWrite, + S: AsyncRead + AsyncWrite + Unpin, { - type SinkItem = BytesMut; - type SinkError = io::Error; + type Error = io::Error; - #[inline] - fn start_send(&mut self, item: Self::SinkItem) -> StartSend { - self.inner.start_send(item) + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Sink::poll_ready(Pin::new(&mut self.inner), cx) } - #[inline] - fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { - self.inner.poll_complete() + fn start_send(mut self: Pin<&mut Self>, item: Vec) -> Result<(), Self::Error> { + Sink::start_send(Pin::new(&mut self.inner), item) } - #[inline] - fn close(&mut self) -> Poll<(), Self::SinkError> { - self.inner.close() + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Sink::poll_flush(Pin::new(&mut self.inner), cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Sink::poll_close(Pin::new(&mut self.inner), cx) } } impl Stream for SecioMiddleware where - S: AsyncRead + AsyncWrite, + S: AsyncRead + AsyncWrite + Unpin, { - type Item = Vec; - type Error = SecioError; + type Item = Result, SecioError>; - #[inline] - fn poll(&mut self) -> Poll, Self::Error> { - self.inner.poll() + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Stream::poll_next(Pin::new(&mut self.inner), cx) } } diff --git a/swarm/Cargo.toml b/swarm/Cargo.toml index a20f9fcb..a1ccfeb3 100644 --- a/swarm/Cargo.toml +++ b/swarm/Cargo.toml @@ -10,15 +10,13 @@ keywords = ["peer-to-peer", "libp2p", "networking"] categories = ["network-programming", "asynchronous"] [dependencies] -futures = "0.1" +futures-preview = "0.3.0-alpha.17" libp2p-core = { version = "0.12.0", path = "../core" } smallvec = "0.6" -tokio-io = "0.1" -wasm-timer = "0.1" +wasm-timer = "0.2" void = "1" [dev-dependencies] libp2p-mplex = { version = "0.12.0", path = "../muxers/mplex" } quickcheck = "0.8" rand = "0.6" - diff --git a/swarm/src/behaviour.rs b/swarm/src/behaviour.rs index e3d72490..aca59112 100644 --- a/swarm/src/behaviour.rs +++ b/swarm/src/behaviour.rs @@ -20,8 +20,7 @@ use crate::protocols_handler::{IntoProtocolsHandler, ProtocolsHandler}; use libp2p_core::{ConnectedPoint, Multiaddr, PeerId, nodes::ListenerId}; -use futures::prelude::*; -use std::error; +use std::{error, task::Context, task::Poll}; /// A behaviour for the network. Allows customizing the swarm. /// @@ -133,8 +132,8 @@ pub trait NetworkBehaviour { /// /// This API mimics the API of the `Stream` trait. The method may register the current task in /// order to wake it up at a later point in time. 
- fn poll(&mut self, params: &mut impl PollParameters) - -> Async::Handler as ProtocolsHandler>::InEvent, Self::OutEvent>>; + fn poll(&mut self, cx: &mut Context, params: &mut impl PollParameters) + -> Poll::Handler as ProtocolsHandler>::InEvent, Self::OutEvent>>; } /// Parameters passed to `poll()`, that the `NetworkBehaviour` has access to. diff --git a/swarm/src/lib.rs b/swarm/src/lib.rs index fd49bdb7..1c455269 100644 --- a/swarm/src/lib.rs +++ b/swarm/src/lib.rs @@ -93,7 +93,7 @@ use libp2p_core::{ }; use registry::{Addresses, AddressIntoIter}; use smallvec::SmallVec; -use std::{error, fmt, io, ops::{Deref, DerefMut}}; +use std::{error, fmt, io, ops::{Deref, DerefMut}, pin::Pin, task::{Context, Poll}}; use std::collections::HashSet; /// Contains the state of the network, plus the way it should behave. @@ -140,14 +140,7 @@ where banned_peers: HashSet, /// Pending event message to be delivered. - /// - /// If the pair's second element is `AsyncSink::NotReady`, the event - /// message has yet to be sent using `PeerMut::start_send_event`. - /// - /// If the pair's second element is `AsyncSink::Ready`, the event - /// message has been sent and needs to be flushed using - /// `PeerMut::complete_send_event`. - send_event_to_complete: Option<(PeerId, AsyncSink)> + send_event_to_complete: Option<(PeerId, TInEvent)> } impl Deref for @@ -172,6 +165,13 @@ where } } +impl Unpin for + ExpandedSwarm +where + TTransport: Transport, +{ +} + impl ExpandedSwarm where TBehaviour: NetworkBehaviour, @@ -180,9 +180,9 @@ where TBehaviour: NetworkBehaviour, ::Substream: Send + 'static, TTransport: Transport + Clone, TTransport::Error: Send + 'static, - TTransport::Listener: Send + 'static, - TTransport::ListenerUpgrade: Send + 'static, - TTransport::Dial: Send + 'static, + TTransport::Listener: Unpin + Send + 'static, + TTransport::ListenerUpgrade: Unpin + Send + 'static, + TTransport::Dial: Unpin + Send + 'static, THandlerErr: error::Error, THandler: IntoProtocolsHandler + Send + 'static, ::Handler: ProtocolsHandler, Error = THandlerErr> + Send + 'static, @@ -315,9 +315,9 @@ where TBehaviour: NetworkBehaviour, ::Substream: Send + 'static, TTransport: Transport + Clone, TTransport::Error: Send + 'static, - TTransport::Listener: Send + 'static, - TTransport::ListenerUpgrade: Send + 'static, - TTransport::Dial: Send + 'static, + TTransport::Listener: Unpin + Send + 'static, + TTransport::ListenerUpgrade: Unpin + Send + 'static, + TTransport::Dial: Unpin + Send + 'static, THandlerErr: error::Error, THandler: IntoProtocolsHandler + Send + 'static, ::Handler: ProtocolsHandler, Error = THandlerErr> + Send + 'static, @@ -340,123 +340,122 @@ where TBehaviour: NetworkBehaviour, ::Handler> as NodeHandler>::OutboundOpenInfo: Send + 'static, // TODO: shouldn't be necessary TConnInfo: ConnectionInfo + fmt::Debug + Clone + Send + 'static, { - type Item = TBehaviour::OutEvent; - type Error = io::Error; + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + // We use a `this` variable to solve borrowing issues. 
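// Illustrative aside (not part of this diff): `ExpandedSwarm` now has an explicit
// `Unpin` impl (added above), so inside `poll_next` the `Pin<&mut Self>` can be
// reborrowed as a plain `&mut Self`; that reborrow is the `this` binding used below.
// A minimal stand-alone version of the same pattern, with a made-up `Counter` stream:
use std::pin::Pin;
use std::task::{Context, Poll};

struct Counter { n: u32 }

impl futures::Stream for Counter {
    type Item = u32;

    fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Option<u32>> {
        // `Counter` is `Unpin`, so this is an ordinary mutable reborrow rather than a
        // pin projection, and every arm of a later `match` can use `this` freely.
        let this = &mut *self;
        this.n += 1;
        Poll::Ready(Some(this.n))
    }
}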
+ let this = &mut *self; - fn poll(&mut self) -> Poll, io::Error> { loop { let mut network_not_ready = false; - match self.network.poll() { - Async::NotReady => network_not_ready = true, - Async::Ready(NetworkEvent::NodeEvent { conn_info, event }) => { - self.behaviour.inject_node_event(conn_info.peer_id().clone(), event); + match this.network.poll(cx) { + Poll::Pending => network_not_ready = true, + Poll::Ready(NetworkEvent::NodeEvent { conn_info, event }) => { + this.behaviour.inject_node_event(conn_info.peer_id().clone(), event); }, - Async::Ready(NetworkEvent::Connected { conn_info, endpoint }) => { - if self.banned_peers.contains(conn_info.peer_id()) { - self.network.peer(conn_info.peer_id().clone()) + Poll::Ready(NetworkEvent::Connected { conn_info, endpoint }) => { + if this.banned_peers.contains(conn_info.peer_id()) { + this.network.peer(conn_info.peer_id().clone()) .into_connected() .expect("the Network just notified us that we were connected; QED") .close(); } else { - self.behaviour.inject_connected(conn_info.peer_id().clone(), endpoint); + this.behaviour.inject_connected(conn_info.peer_id().clone(), endpoint); } }, - Async::Ready(NetworkEvent::NodeClosed { conn_info, endpoint, .. }) => { - self.behaviour.inject_disconnected(conn_info.peer_id(), endpoint); + Poll::Ready(NetworkEvent::NodeClosed { conn_info, endpoint, .. }) => { + this.behaviour.inject_disconnected(conn_info.peer_id(), endpoint); }, - Async::Ready(NetworkEvent::Replaced { new_info, closed_endpoint, endpoint, .. }) => { - self.behaviour.inject_replaced(new_info.peer_id().clone(), closed_endpoint, endpoint); + Poll::Ready(NetworkEvent::Replaced { new_info, closed_endpoint, endpoint, .. }) => { + this.behaviour.inject_replaced(new_info.peer_id().clone(), closed_endpoint, endpoint); }, - Async::Ready(NetworkEvent::IncomingConnection(incoming)) => { - let handler = self.behaviour.new_handler(); + Poll::Ready(NetworkEvent::IncomingConnection(incoming)) => { + let handler = this.behaviour.new_handler(); incoming.accept(handler.into_node_handler_builder()); }, - Async::Ready(NetworkEvent::NewListenerAddress { listen_addr, .. }) => { - if !self.listened_addrs.contains(&listen_addr) { - self.listened_addrs.push(listen_addr.clone()) + Poll::Ready(NetworkEvent::NewListenerAddress { listen_addr, .. }) => { + if !this.listened_addrs.contains(&listen_addr) { + this.listened_addrs.push(listen_addr.clone()) } - self.behaviour.inject_new_listen_addr(&listen_addr); + this.behaviour.inject_new_listen_addr(&listen_addr); } - Async::Ready(NetworkEvent::ExpiredListenerAddress { listen_addr, .. }) => { - self.listened_addrs.retain(|a| a != &listen_addr); - self.behaviour.inject_expired_listen_addr(&listen_addr); + Poll::Ready(NetworkEvent::ExpiredListenerAddress { listen_addr, .. }) => { + this.listened_addrs.retain(|a| a != &listen_addr); + this.behaviour.inject_expired_listen_addr(&listen_addr); } - Async::Ready(NetworkEvent::ListenerClosed { listener_id, .. }) => - self.behaviour.inject_listener_closed(listener_id), - Async::Ready(NetworkEvent::ListenerError { listener_id, error }) => - self.behaviour.inject_listener_error(listener_id, &error), - Async::Ready(NetworkEvent::IncomingConnectionError { .. }) => {}, - Async::Ready(NetworkEvent::DialError { peer_id, multiaddr, error, new_state }) => { - self.behaviour.inject_addr_reach_failure(Some(&peer_id), &multiaddr, &error); + Poll::Ready(NetworkEvent::ListenerClosed { listener_id, .. 
}) => + this.behaviour.inject_listener_closed(listener_id), + Poll::Ready(NetworkEvent::ListenerError { listener_id, error }) => + this.behaviour.inject_listener_error(listener_id, &error), + Poll::Ready(NetworkEvent::IncomingConnectionError { .. }) => {}, + Poll::Ready(NetworkEvent::DialError { peer_id, multiaddr, error, new_state }) => { + this.behaviour.inject_addr_reach_failure(Some(&peer_id), &multiaddr, &error); if let network::PeerState::NotConnected = new_state { - self.behaviour.inject_dial_failure(&peer_id); + this.behaviour.inject_dial_failure(&peer_id); } }, - Async::Ready(NetworkEvent::UnknownPeerDialError { multiaddr, error, .. }) => { - self.behaviour.inject_addr_reach_failure(None, &multiaddr, &error); + Poll::Ready(NetworkEvent::UnknownPeerDialError { multiaddr, error, .. }) => { + this.behaviour.inject_addr_reach_failure(None, &multiaddr, &error); }, } // Try to deliver pending event. - if let Some((id, pending)) = self.send_event_to_complete.take() { - if let Some(mut peer) = self.network.peer(id.clone()).into_connected() { - if let AsyncSink::NotReady(e) = pending { - if let Ok(a@AsyncSink::NotReady(_)) = peer.start_send_event(e) { - self.send_event_to_complete = Some((id, a)) - } else if let Ok(Async::NotReady) = peer.complete_send_event() { - self.send_event_to_complete = Some((id, AsyncSink::Ready)) - } - } else if let Ok(Async::NotReady) = peer.complete_send_event() { - self.send_event_to_complete = Some((id, AsyncSink::Ready)) + if let Some((id, pending)) = this.send_event_to_complete.take() { + if let Some(mut peer) = this.network.peer(id.clone()).into_connected() { + match peer.poll_ready_event(cx) { + Poll::Ready(()) => peer.start_send_event(pending), + Poll::Pending => { + this.send_event_to_complete = Some((id, pending)); + return Poll::Pending + }, } } } - if self.send_event_to_complete.is_some() { - return Ok(Async::NotReady) - } let behaviour_poll = { let mut parameters = SwarmPollParameters { - local_peer_id: &mut self.network.local_peer_id(), - supported_protocols: &self.supported_protocols, - listened_addrs: &self.listened_addrs, - external_addrs: &self.external_addrs + local_peer_id: &mut this.network.local_peer_id(), + supported_protocols: &this.supported_protocols, + listened_addrs: &this.listened_addrs, + external_addrs: &this.external_addrs }; - self.behaviour.poll(&mut parameters) + this.behaviour.poll(cx, &mut parameters) }; match behaviour_poll { - Async::NotReady if network_not_ready => return Ok(Async::NotReady), - Async::NotReady => (), - Async::Ready(NetworkBehaviourAction::GenerateEvent(event)) => { - return Ok(Async::Ready(Some(event))) + Poll::Pending if network_not_ready => return Poll::Pending, + Poll::Pending => (), + Poll::Ready(NetworkBehaviourAction::GenerateEvent(event)) => { + return Poll::Ready(Some(Ok(event))) }, - Async::Ready(NetworkBehaviourAction::DialAddress { address }) => { - let _ = ExpandedSwarm::dial_addr(self, address); + Poll::Ready(NetworkBehaviourAction::DialAddress { address }) => { + let _ = ExpandedSwarm::dial_addr(&mut *this, address); }, - Async::Ready(NetworkBehaviourAction::DialPeer { peer_id }) => { - if self.banned_peers.contains(&peer_id) { - self.behaviour.inject_dial_failure(&peer_id); + Poll::Ready(NetworkBehaviourAction::DialPeer { peer_id }) => { + if this.banned_peers.contains(&peer_id) { + this.behaviour.inject_dial_failure(&peer_id); } else { - ExpandedSwarm::dial(self, peer_id); + ExpandedSwarm::dial(&mut *this, peer_id); } }, - Async::Ready(NetworkBehaviourAction::SendEvent { peer_id, event }) => 
{ - if let Some(mut peer) = self.network.peer(peer_id.clone()).into_connected() { - if let Ok(a@AsyncSink::NotReady(_)) = peer.start_send_event(event) { - self.send_event_to_complete = Some((peer_id, a)) - } else if let Ok(Async::NotReady) = peer.complete_send_event() { - self.send_event_to_complete = Some((peer_id, AsyncSink::Ready)) + Poll::Ready(NetworkBehaviourAction::SendEvent { peer_id, event }) => { + if let Some(mut peer) = this.network.peer(peer_id.clone()).into_connected() { + if let Poll::Ready(()) = peer.poll_ready_event(cx) { + peer.start_send_event(event); + } else { + debug_assert!(this.send_event_to_complete.is_none()); + this.send_event_to_complete = Some((peer_id, event)); + return Poll::Pending; } } }, - Async::Ready(NetworkBehaviourAction::ReportObservedAddr { address }) => { - for addr in self.network.address_translation(&address) { - if self.external_addrs.iter().all(|a| *a != addr) { - self.behaviour.inject_new_external_addr(&addr); + Poll::Ready(NetworkBehaviourAction::ReportObservedAddr { address }) => { + for addr in this.network.address_translation(&address) { + if this.external_addrs.iter().all(|a| *a != addr) { + this.behaviour.inject_new_external_addr(&addr); } - self.external_addrs.add(addr) + this.external_addrs.add(addr) } }, } @@ -509,9 +508,9 @@ where TBehaviour: NetworkBehaviour, ::Substream: Send + 'static, TTransport: Transport + Clone, TTransport::Error: Send + 'static, - TTransport::Listener: Send + 'static, - TTransport::ListenerUpgrade: Send + 'static, - TTransport::Dial: Send + 'static, + TTransport::Listener: Unpin + Send + 'static, + TTransport::ListenerUpgrade: Unpin + Send + 'static, + TTransport::Dial: Unpin + Send + 'static, ::ProtocolsHandler: Send + 'static, <::ProtocolsHandler as IntoProtocolsHandler>::Handler: ProtocolsHandler> + Send + 'static, <<::ProtocolsHandler as IntoProtocolsHandler>::Handler as ProtocolsHandler>::InEvent: Send + 'static, @@ -584,8 +583,7 @@ mod tests { }; use libp2p_mplex::Multiplex; use futures::prelude::*; - use std::marker::PhantomData; - use tokio_io::{AsyncRead, AsyncWrite}; + use std::{marker::PhantomData, task::Context, task::Poll}; use void::Void; #[derive(Clone)] @@ -593,11 +591,9 @@ mod tests { marker: PhantomData, } - trait TSubstream: AsyncRead + AsyncWrite {} - impl NetworkBehaviour for DummyBehaviour - where TSubstream: AsyncRead + AsyncWrite + where TSubstream: AsyncRead + AsyncWrite + Unpin { type ProtocolsHandler = DummyProtocolsHandler; type OutEvent = Void; @@ -617,11 +613,11 @@ mod tests { fn inject_node_event(&mut self, _: PeerId, _: ::OutEvent) {} - fn poll(&mut self, _: &mut impl PollParameters) -> - Async + Poll::InEvent, Self::OutEvent>> { - Async::NotReady + Poll::Pending } } diff --git a/swarm/src/protocols_handler/dummy.rs b/swarm/src/protocols_handler/dummy.rs index a9719b85..f3c6052d 100644 --- a/swarm/src/protocols_handler/dummy.rs +++ b/swarm/src/protocols_handler/dummy.rs @@ -27,8 +27,7 @@ use crate::protocols_handler::{ }; use futures::prelude::*; use libp2p_core::upgrade::{InboundUpgrade, OutboundUpgrade, DeniedUpgrade}; -use std::marker::PhantomData; -use tokio_io::{AsyncRead, AsyncWrite}; +use std::{marker::PhantomData, task::Context, task::Poll}; use void::Void; /// Implementation of `ProtocolsHandler` that doesn't handle anything. 
@@ -47,7 +46,7 @@ impl Default for DummyProtocolsHandler { impl ProtocolsHandler for DummyProtocolsHandler where - TSubstream: AsyncRead + AsyncWrite, + TSubstream: AsyncRead + AsyncWrite + Unpin, { type InEvent = Void; type OutEvent = Void; @@ -89,10 +88,10 @@ where #[inline] fn poll( &mut self, + _: &mut Context, ) -> Poll< - ProtocolsHandlerEvent, - Void, + ProtocolsHandlerEvent, > { - Ok(Async::NotReady) + Poll::Pending } } diff --git a/swarm/src/protocols_handler/map_in.rs b/swarm/src/protocols_handler/map_in.rs index e478e58f..dedae4a9 100644 --- a/swarm/src/protocols_handler/map_in.rs +++ b/swarm/src/protocols_handler/map_in.rs @@ -25,9 +25,8 @@ use crate::protocols_handler::{ ProtocolsHandlerEvent, ProtocolsHandlerUpgrErr }; -use futures::prelude::*; use libp2p_core::upgrade::{InboundUpgrade, OutboundUpgrade}; -use std::marker::PhantomData; +use std::{marker::PhantomData, task::Context, task::Poll}; /// Wrapper around a protocol handler that turns the input event into something else. pub struct MapInEvent { @@ -103,10 +102,10 @@ where #[inline] fn poll( &mut self, + cx: &mut Context, ) -> Poll< - ProtocolsHandlerEvent, - Self::Error, + ProtocolsHandlerEvent, > { - self.inner.poll() + self.inner.poll(cx) } } diff --git a/swarm/src/protocols_handler/map_out.rs b/swarm/src/protocols_handler/map_out.rs index 5815d949..4bc04791 100644 --- a/swarm/src/protocols_handler/map_out.rs +++ b/swarm/src/protocols_handler/map_out.rs @@ -25,8 +25,8 @@ use crate::protocols_handler::{ ProtocolsHandlerEvent, ProtocolsHandlerUpgrErr }; -use futures::prelude::*; use libp2p_core::upgrade::{InboundUpgrade, OutboundUpgrade}; +use std::task::{Context, Poll}; /// Wrapper around a protocol handler that turns the output event into something else. pub struct MapOutEvent { @@ -98,17 +98,18 @@ where #[inline] fn poll( &mut self, + cx: &mut Context, ) -> Poll< - ProtocolsHandlerEvent, - Self::Error, + ProtocolsHandlerEvent, > { - Ok(self.inner.poll()?.map(|ev| { + self.inner.poll(cx).map(|ev| { match ev { ProtocolsHandlerEvent::Custom(ev) => ProtocolsHandlerEvent::Custom((self.map)(ev)), + ProtocolsHandlerEvent::Close(err) => ProtocolsHandlerEvent::Close(err), ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol, info } => { ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol, info } } } - })) + }) } } diff --git a/swarm/src/protocols_handler/mod.rs b/swarm/src/protocols_handler/mod.rs index 855d95d4..8b7dbe71 100644 --- a/swarm/src/protocols_handler/mod.rs +++ b/swarm/src/protocols_handler/mod.rs @@ -50,8 +50,7 @@ use libp2p_core::{ PeerId, upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeError}, }; -use std::{cmp::Ordering, error, fmt, time::Duration}; -use tokio_io::{AsyncRead, AsyncWrite}; +use std::{cmp::Ordering, error, fmt, task::Context, task::Poll, time::Duration}; use wasm_timer::Instant; pub use dummy::DummyProtocolsHandler; @@ -101,7 +100,7 @@ pub trait ProtocolsHandler { /// The type of errors returned by [`ProtocolsHandler::poll`]. type Error: error::Error; /// The type of substreams on which the protocol(s) are negotiated. - type Substream: AsyncRead + AsyncWrite; + type Substream: AsyncRead + AsyncWrite + Unpin; /// The inbound upgrade for the protocol(s) used by the handler. type InboundProtocol: InboundUpgrade; /// The outbound upgrade for the protocol(s) used by the handler. @@ -171,9 +170,8 @@ pub trait ProtocolsHandler { /// Should behave like `Stream::poll()`. /// /// Returning an error will close the connection to the remote. 
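// Illustrative sketch (not part of this diff): with the new signature introduced just
// below, a handler no longer returns `Err(..)` from `poll`; a fatal error is yielded as
// a `ProtocolsHandlerEvent::Close` event instead, which is what `OneShotHandler::poll`
// does further down. Only the body of `poll` is shown here, and `self.pending_error` is
// a hypothetical `Option<Self::Error>` field used for the example.
fn poll(&mut self, _: &mut Context)
    -> Poll<ProtocolsHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::OutEvent, Self::Error>>
{
    if let Some(err) = self.pending_error.take() {
        // Yielding `Close` makes the connection (and this handler) shut down with `err`.
        return Poll::Ready(ProtocolsHandlerEvent::Close(err));
    }
    Poll::Pending
}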
- fn poll(&mut self) -> Poll< - ProtocolsHandlerEvent, - Self::Error + fn poll(&mut self, cx: &mut Context) -> Poll< + ProtocolsHandlerEvent >; /// Adds a closure that turns the input event into something else. @@ -300,7 +298,7 @@ impl From for SubstreamProtocol { /// Event produced by a handler. #[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub enum ProtocolsHandlerEvent { +pub enum ProtocolsHandlerEvent { /// Request a new outbound substream to be opened with the remote. OutboundSubstreamRequest { /// The protocol(s) to apply on the substream. @@ -309,13 +307,16 @@ pub enum ProtocolsHandlerEvent { info: TOutboundOpenInfo, }, + /// Close the connection for the given reason. + Close(TErr), + /// Other event. Custom(TCustom), } /// Event produced by a handler. -impl - ProtocolsHandlerEvent +impl + ProtocolsHandlerEvent { /// If this is an `OutboundSubstreamRequest`, maps the `info` member from a /// `TOutboundOpenInfo` to something else. @@ -323,7 +324,7 @@ impl pub fn map_outbound_open_info( self, map: F, - ) -> ProtocolsHandlerEvent + ) -> ProtocolsHandlerEvent where F: FnOnce(TOutboundOpenInfo) -> I, { @@ -335,6 +336,7 @@ impl } } ProtocolsHandlerEvent::Custom(val) => ProtocolsHandlerEvent::Custom(val), + ProtocolsHandlerEvent::Close(val) => ProtocolsHandlerEvent::Close(val), } } @@ -344,7 +346,7 @@ impl pub fn map_protocol( self, map: F, - ) -> ProtocolsHandlerEvent + ) -> ProtocolsHandlerEvent where F: FnOnce(TConnectionUpgrade) -> I, { @@ -356,6 +358,7 @@ impl } } ProtocolsHandlerEvent::Custom(val) => ProtocolsHandlerEvent::Custom(val), + ProtocolsHandlerEvent::Close(val) => ProtocolsHandlerEvent::Close(val), } } @@ -364,7 +367,7 @@ impl pub fn map_custom( self, map: F, - ) -> ProtocolsHandlerEvent + ) -> ProtocolsHandlerEvent where F: FnOnce(TCustom) -> I, { @@ -373,6 +376,25 @@ impl ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol, info } } ProtocolsHandlerEvent::Custom(val) => ProtocolsHandlerEvent::Custom(map(val)), + ProtocolsHandlerEvent::Close(val) => ProtocolsHandlerEvent::Close(val), + } + } + + /// If this is a `Close` event, maps the content to something else. + #[inline] + pub fn map_close( + self, + map: F, + ) -> ProtocolsHandlerEvent + where + F: FnOnce(TErr) -> I, + { + match self { + ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol, info } => { + ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol, info } + } + ProtocolsHandlerEvent::Custom(val) => ProtocolsHandlerEvent::Custom(val), + ProtocolsHandlerEvent::Close(val) => ProtocolsHandlerEvent::Close(map(val)), } } } diff --git a/swarm/src/protocols_handler/node_handler.rs b/swarm/src/protocols_handler/node_handler.rs index 14b2e01c..289aa05b 100644 --- a/swarm/src/protocols_handler/node_handler.rs +++ b/swarm/src/protocols_handler/node_handler.rs @@ -33,8 +33,8 @@ use libp2p_core::{ nodes::handled_node::{IntoNodeHandler, NodeHandler, NodeHandlerEndpoint, NodeHandlerEvent}, upgrade::{self, InboundUpgradeApply, OutboundUpgradeApply} }; -use std::{error, fmt, time::Duration}; -use wasm_timer::{Delay, Timeout}; +use std::{error, fmt, pin::Pin, task::Context, task::Poll, time::Duration}; +use wasm_timer::{Delay, Instant}; /// Prototype for a `NodeHandlerWrapper`. pub struct NodeHandlerWrapperBuilder { @@ -102,12 +102,13 @@ where handler: TProtoHandler, /// Futures that upgrade incoming substreams. negotiating_in: - Vec>>, + Vec<(InboundUpgradeApply, Delay)>, /// Futures that upgrade outgoing substreams. The first element of the tuple is the userdata /// to pass back once successfully opened. 
negotiating_out: Vec<( TProtoHandler::OutboundOpenInfo, - Timeout>, + OutboundUpgradeApply, + Delay, )>, /// For each outbound substream request, how to upgrade it. The first element of the tuple /// is the unique identifier (see `unique_dial_upgrade_id`). @@ -133,7 +134,7 @@ enum Shutdown { /// A shut down is planned as soon as possible. Asap, /// A shut down is planned for when a `Delay` has elapsed. - Later(Delay) + Later(Delay, Instant) } /// Error generated by the `NodeHandlerWrapper`. @@ -198,8 +199,8 @@ where let protocol = self.handler.listen_protocol(); let timeout = protocol.timeout().clone(); let upgrade = upgrade::apply_inbound(substream, protocol.into_upgrade()); - let with_timeout = Timeout::new(upgrade, timeout); - self.negotiating_in.push(with_timeout); + let timeout = Delay::new(timeout); + self.negotiating_in.push((upgrade, timeout)); } NodeHandlerEndpoint::Dialer((upgrade_id, user_data, timeout)) => { let pos = match self @@ -216,8 +217,8 @@ where let (_, proto_upgrade) = self.queued_dial_upgrades.remove(pos); let upgrade = upgrade::apply_outbound(substream, proto_upgrade); - let with_timeout = Timeout::new(upgrade, timeout); - self.negotiating_out.push((user_data, with_timeout)); + let timeout = Delay::new(timeout); + self.negotiating_out.push((user_data, upgrade, timeout)); } } } @@ -227,44 +228,50 @@ where self.handler.inject_event(event); } - fn poll(&mut self) -> Poll, Self::Error> { + fn poll(&mut self, cx: &mut Context) -> Poll, Self::Error>> { // Continue negotiation of newly-opened substreams on the listening side. // We remove each element from `negotiating_in` one by one and add them back if not ready. for n in (0..self.negotiating_in.len()).rev() { - let mut in_progress = self.negotiating_in.swap_remove(n); - match in_progress.poll() { - Ok(Async::Ready(upgrade)) => + let (mut in_progress, mut timeout) = self.negotiating_in.swap_remove(n); + match Future::poll(Pin::new(&mut timeout), cx) { + Poll::Ready(_) => continue, + Poll::Pending => {}, + } + match Future::poll(Pin::new(&mut in_progress), cx) { + Poll::Ready(Ok(upgrade)) => self.handler.inject_fully_negotiated_inbound(upgrade), - Ok(Async::NotReady) => self.negotiating_in.push(in_progress), + Poll::Pending => self.negotiating_in.push((in_progress, timeout)), // TODO: return a diagnostic event? - Err(_err) => {} + Poll::Ready(Err(_err)) => {} } } // Continue negotiation of newly-opened substreams. // We remove each element from `negotiating_out` one by one and add them back if not ready. 
for n in (0..self.negotiating_out.len()).rev() { - let (upgr_info, mut in_progress) = self.negotiating_out.swap_remove(n); - match in_progress.poll() { - Ok(Async::Ready(upgrade)) => { + let (upgr_info, mut in_progress, mut timeout) = self.negotiating_out.swap_remove(n); + match Future::poll(Pin::new(&mut timeout), cx) { + Poll::Ready(Ok(_)) => { + let err = ProtocolsHandlerUpgrErr::Timeout; + self.handler.inject_dial_upgrade_error(upgr_info, err); + continue; + }, + Poll::Ready(Err(_)) => { + let err = ProtocolsHandlerUpgrErr::Timer; + self.handler.inject_dial_upgrade_error(upgr_info, err); + continue; + }, + Poll::Pending => {}, + } + match Future::poll(Pin::new(&mut in_progress), cx) { + Poll::Ready(Ok(upgrade)) => { self.handler.inject_fully_negotiated_outbound(upgrade, upgr_info); } - Ok(Async::NotReady) => { - self.negotiating_out.push((upgr_info, in_progress)); + Poll::Pending => { + self.negotiating_out.push((upgr_info, in_progress, timeout)); } - Err(err) => { - let err = if err.is_elapsed() { - ProtocolsHandlerUpgrErr::Timeout - } else if err.is_timer() { - ProtocolsHandlerUpgrErr::Timer - } else { - debug_assert!(err.is_inner()); - let err = err.into_inner().expect("Timeout error is one of {elapsed, \ - timer, inner}; is_elapsed and is_timer are both false; error is \ - inner; QED"); - ProtocolsHandlerUpgrErr::Upgrade(err) - }; - + Poll::Ready(Err(err)) => { + let err = ProtocolsHandlerUpgrErr::Upgrade(err); self.handler.inject_dial_upgrade_error(upgr_info, err); } } @@ -272,25 +279,26 @@ where // Poll the handler at the end so that we see the consequences of the method // calls on `self.handler`. - let poll_result = self.handler.poll()?; + let poll_result = self.handler.poll(cx); // Ask the handler whether it wants the connection (and the handler itself) // to be kept alive, which determines the planned shutdown, if any. match (&mut self.shutdown, self.handler.connection_keep_alive()) { - (Shutdown::Later(d), KeepAlive::Until(t)) => - if d.deadline() != t { - d.reset(t) + (Shutdown::Later(timer, deadline), KeepAlive::Until(t)) => + if *deadline != t { + *deadline = t; + timer.reset_at(t) }, - (_, KeepAlive::Until(t)) => self.shutdown = Shutdown::Later(Delay::new(t)), + (_, KeepAlive::Until(t)) => self.shutdown = Shutdown::Later(Delay::new_at(t), t), (_, KeepAlive::No) => self.shutdown = Shutdown::Asap, (_, KeepAlive::Yes) => self.shutdown = Shutdown::None }; match poll_result { - Async::Ready(ProtocolsHandlerEvent::Custom(event)) => { - return Ok(Async::Ready(NodeHandlerEvent::Custom(event))); + Poll::Ready(ProtocolsHandlerEvent::Custom(event)) => { + return Poll::Ready(Ok(NodeHandlerEvent::Custom(event))); } - Async::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { + Poll::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol, info, }) => { @@ -298,11 +306,12 @@ where let timeout = protocol.timeout().clone(); self.unique_dial_upgrade_id += 1; self.queued_dial_upgrades.push((id, protocol.into_upgrade())); - return Ok(Async::Ready( + return Poll::Ready(Ok( NodeHandlerEvent::OutboundSubstreamRequest((id, info, timeout)), )); } - Async::NotReady => (), + Poll::Ready(ProtocolsHandlerEvent::Close(err)) => return Poll::Ready(Err(err.into())), + Poll::Pending => (), }; // Check if the connection (and handler) should be shut down. 
@@ -310,15 +319,14 @@ where if self.negotiating_in.is_empty() && self.negotiating_out.is_empty() { match self.shutdown { Shutdown::None => {}, - Shutdown::Asap => return Err(NodeHandlerWrapperError::UselessTimeout), - Shutdown::Later(ref mut delay) => match delay.poll() { - Ok(Async::Ready(_)) | Err(_) => - return Err(NodeHandlerWrapperError::UselessTimeout), - Ok(Async::NotReady) => {} + Shutdown::Asap => return Poll::Ready(Err(NodeHandlerWrapperError::UselessTimeout)), + Shutdown::Later(ref mut delay, _) => match Future::poll(Pin::new(delay), cx) { + Poll::Ready(_) => return Poll::Ready(Err(NodeHandlerWrapperError::UselessTimeout)), + Poll::Pending => {} } } } - Ok(Async::NotReady) + Poll::Pending } } diff --git a/swarm/src/protocols_handler/one_shot.rs b/swarm/src/protocols_handler/one_shot.rs index c685dfb9..40da87d0 100644 --- a/swarm/src/protocols_handler/one_shot.rs +++ b/swarm/src/protocols_handler/one_shot.rs @@ -28,8 +28,7 @@ use crate::protocols_handler::{ use futures::prelude::*; use libp2p_core::upgrade::{InboundUpgrade, OutboundUpgrade}; use smallvec::SmallVec; -use std::{error, marker::PhantomData, time::Duration}; -use tokio_io::{AsyncRead, AsyncWrite}; +use std::{error, marker::PhantomData, task::Context, task::Poll, time::Duration}; use wasm_timer::Instant; /// Implementation of `ProtocolsHandler` that opens a new substream for each individual message. @@ -132,7 +131,7 @@ where impl ProtocolsHandler for OneShotHandler where - TSubstream: AsyncRead + AsyncWrite, + TSubstream: AsyncRead + AsyncWrite + Unpin, TInProto: InboundUpgrade, TOutProto: OutboundUpgrade, TInProto::Output: Into, @@ -208,18 +207,18 @@ where fn poll( &mut self, + _: &mut Context, ) -> Poll< - ProtocolsHandlerEvent, - Self::Error, + ProtocolsHandlerEvent, > { if let Some(err) = self.pending_error.take() { - return Err(err); + return Poll::Ready(ProtocolsHandlerEvent::Close(err)); } if !self.events_out.is_empty() { - return Ok(Async::Ready(ProtocolsHandlerEvent::Custom( + return Poll::Ready(ProtocolsHandlerEvent::Custom( self.events_out.remove(0), - ))); + )); } else { self.events_out.shrink_to_fit(); } @@ -227,17 +226,17 @@ where if !self.dial_queue.is_empty() { if self.dial_negotiated < self.max_dial_negotiated { self.dial_negotiated += 1; - return Ok(Async::Ready( + return Poll::Ready( ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol: SubstreamProtocol::new(self.dial_queue.remove(0)), info: (), }, - )); + ); } } else { self.dial_queue.shrink_to_fit(); } - Ok(Async::NotReady) + Poll::Pending } } diff --git a/swarm/src/protocols_handler/select.rs b/swarm/src/protocols_handler/select.rs index 074920b1..f030fbe5 100644 --- a/swarm/src/protocols_handler/select.rs +++ b/swarm/src/protocols_handler/select.rs @@ -33,8 +33,7 @@ use libp2p_core::{ either::{EitherError, EitherOutput}, upgrade::{InboundUpgrade, OutboundUpgrade, EitherUpgrade, SelectUpgrade, UpgradeError} }; -use std::cmp; -use tokio_io::{AsyncRead, AsyncWrite}; +use std::{cmp, task::Context, task::Poll}; /// Implementation of `IntoProtocolsHandler` that combines two protocols into one. 
#[derive(Debug, Clone)] @@ -62,7 +61,7 @@ where TProto2: IntoProtocolsHandler, TProto1::Handler: ProtocolsHandler, TProto2::Handler: ProtocolsHandler, - TSubstream: AsyncRead + AsyncWrite, + TSubstream: AsyncRead + AsyncWrite + Unpin, ::InboundProtocol: InboundUpgrade, ::InboundProtocol: InboundUpgrade, ::OutboundProtocol: OutboundUpgrade, @@ -107,7 +106,7 @@ impl where TProto1: ProtocolsHandler, TProto2: ProtocolsHandler, - TSubstream: AsyncRead + AsyncWrite, + TSubstream: AsyncRead + AsyncWrite + Unpin, TProto1::InboundProtocol: InboundUpgrade, TProto2::InboundProtocol: InboundUpgrade, TProto1::OutboundProtocol: OutboundUpgrade, @@ -201,40 +200,46 @@ where cmp::max(self.proto1.connection_keep_alive(), self.proto2.connection_keep_alive()) } - fn poll(&mut self) -> Poll, Self::Error> { + fn poll(&mut self, cx: &mut Context) -> Poll> { - match self.proto1.poll().map_err(EitherError::A)? { - Async::Ready(ProtocolsHandlerEvent::Custom(event)) => { - return Ok(Async::Ready(ProtocolsHandlerEvent::Custom(EitherOutput::First(event)))); + match self.proto1.poll(cx) { + Poll::Ready(ProtocolsHandlerEvent::Custom(event)) => { + return Poll::Ready(ProtocolsHandlerEvent::Custom(EitherOutput::First(event))); }, - Async::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { + Poll::Ready(ProtocolsHandlerEvent::Close(event)) => { + return Poll::Ready(ProtocolsHandlerEvent::Close(EitherError::A(event))); + }, + Poll::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol, info, }) => { - return Ok(Async::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { + return Poll::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol: protocol.map_upgrade(EitherUpgrade::A), info: EitherOutput::First(info), - })); + }); }, - Async::NotReady => () + Poll::Pending => () }; - match self.proto2.poll().map_err(EitherError::B)? { - Async::Ready(ProtocolsHandlerEvent::Custom(event)) => { - return Ok(Async::Ready(ProtocolsHandlerEvent::Custom(EitherOutput::Second(event)))); + match self.proto2.poll(cx) { + Poll::Ready(ProtocolsHandlerEvent::Custom(event)) => { + return Poll::Ready(ProtocolsHandlerEvent::Custom(EitherOutput::Second(event))); }, - Async::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { + Poll::Ready(ProtocolsHandlerEvent::Close(event)) => { + return Poll::Ready(ProtocolsHandlerEvent::Close(EitherError::B(event))); + }, + Poll::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol, info, }) => { - return Ok(Async::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { + return Poll::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol: protocol.map_upgrade(EitherUpgrade::B), info: EitherOutput::Second(info), - })); + }); }, - Async::NotReady => () + Poll::Pending => () }; - Ok(Async::NotReady) + Poll::Pending } } diff --git a/swarm/src/toggle.rs b/swarm/src/toggle.rs index 002ab626..c4e42e35 100644 --- a/swarm/src/toggle.rs +++ b/swarm/src/toggle.rs @@ -34,8 +34,7 @@ use libp2p_core::{ either::EitherOutput, upgrade::{InboundUpgrade, OutboundUpgrade, DeniedUpgrade, EitherUpgrade} }; -use futures::prelude::*; -use std::error; +use std::{error, task::Context, task::Poll}; /// Implementation of `NetworkBehaviour` that can be either in the disabled or enabled state. 
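// Sketch of the combined-handler polling above (`ProtocolsHandlerSelect`): each
// sub-handler is polled in turn, the first ready event wins and is tagged with a
// left/right variant, and only when both are pending does the combined poll return
// `Poll::Pending`. `poll_either` and its closure parameters are illustrative.
use futures::future::Either;
use std::task::{Context, Poll};

fn poll_either<A, B>(
    cx: &mut Context<'_>,
    mut poll_first: impl FnMut(&mut Context<'_>) -> Poll<A>,
    mut poll_second: impl FnMut(&mut Context<'_>) -> Poll<B>,
) -> Poll<Either<A, B>> {
    if let Poll::Ready(event) = poll_first(cx) {
        return Poll::Ready(Either::Left(event));
    }
    if let Poll::Ready(event) = poll_second(cx) {
        return Poll::Ready(Either::Right(event));
    }
    Poll::Pending
}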
/// @@ -132,13 +131,13 @@ where } } - fn poll(&mut self, params: &mut impl PollParameters) - -> Async::Handler as ProtocolsHandler>::InEvent, Self::OutEvent>> + fn poll(&mut self, cx: &mut Context, params: &mut impl PollParameters) + -> Poll::Handler as ProtocolsHandler>::InEvent, Self::OutEvent>> { if let Some(inner) = self.inner.as_mut() { - inner.poll(params) + inner.poll(cx, params) } else { - Async::NotReady + Poll::Pending } } } @@ -244,14 +243,14 @@ where fn poll( &mut self, + cx: &mut Context, ) -> Poll< - ProtocolsHandlerEvent, - Self::Error, + ProtocolsHandlerEvent > { if let Some(inner) = self.inner.as_mut() { - inner.poll() + inner.poll(cx) } else { - Ok(Async::NotReady) + Poll::Pending } } } diff --git a/transports/dns/Cargo.toml b/transports/dns/Cargo.toml index e1ee1d62..f16cf4d8 100644 --- a/transports/dns/Cargo.toml +++ b/transports/dns/Cargo.toml @@ -12,8 +12,4 @@ categories = ["network-programming", "asynchronous"] [dependencies] libp2p-core = { version = "0.12.0", path = "../../core" } log = "0.4.1" -futures = "0.1" -tokio-dns-unofficial = "0.4" - -[dev-dependencies] -libp2p-tcp = { version = "0.12.0", path = "../../transports/tcp" } +futures-preview = "0.3.0-alpha.17" diff --git a/transports/dns/src/lib.rs b/transports/dns/src/lib.rs index 7f0dddfd..95a1db9e 100644 --- a/transports/dns/src/lib.rs +++ b/transports/dns/src/lib.rs @@ -33,15 +33,14 @@ //! replaced with respectively an `/ip4/` or an `/ip6/` component. //! -use futures::{future::{self, Either, FutureResult, JoinAll}, prelude::*, stream, try_ready}; +use futures::{prelude::*, channel::oneshot}; use libp2p_core::{ Transport, multiaddr::{Protocol, Multiaddr}, transport::{TransportError, ListenerEvent} }; -use log::{debug, trace, log_enabled, Level}; -use std::{error, fmt, io, marker::PhantomData, net::IpAddr}; -use tokio_dns::{CpuPoolResolver, Resolver}; +use log::{error, debug, trace}; +use std::{error, fmt, io, net::ToSocketAddrs, pin::Pin}; /// Represents the configuration for a DNS transport capability of libp2p. /// @@ -52,24 +51,31 @@ use tokio_dns::{CpuPoolResolver, Resolver}; /// Listening is unaffected. #[derive(Clone)] pub struct DnsConfig { + /// Underlying transport to use once the DNS addresses have been resolved. inner: T, - resolver: CpuPoolResolver, + /// Pool of threads to use when resolving DNS addresses. + thread_pool: futures::executor::ThreadPool, } impl DnsConfig { /// Creates a new configuration object for DNS. - pub fn new(inner: T) -> DnsConfig { + pub fn new(inner: T) -> Result, io::Error> { DnsConfig::with_resolve_threads(inner, 1) } /// Same as `new`, but allows specifying a number of threads for the resolving. 
- pub fn with_resolve_threads(inner: T, num_threads: usize) -> DnsConfig { - trace!("Created a CpuPoolResolver"); + pub fn with_resolve_threads(inner: T, num_threads: usize) -> Result, io::Error> { + let thread_pool = futures::executor::ThreadPool::builder() + .pool_size(num_threads) + .name_prefix("libp2p-dns-") + .create()?; - DnsConfig { + trace!("Created a DNS thread pool"); + + Ok(DnsConfig { inner, - resolver: CpuPoolResolver::new(num_threads), - } + thread_pool, + }) } } @@ -84,34 +90,34 @@ where impl Transport for DnsConfig where - T: Transport, + T: Transport + 'static, T::Error: 'static, { type Output = T::Output; type Error = DnsErr; type Listener = stream::MapErr< - stream::Map) -> ListenerEvent>, fn(T::Error) -> Self::Error>; type ListenerUpgrade = future::MapErr Self::Error>; - type Dial = Either Self::Error>, - DialFuture>, T::Error>, - FutureResult, Self::Error>>>> - >> + type Dial = future::Either< + future::MapErr Self::Error>, + Pin>>> >; fn listen_on(self, addr: Multiaddr) -> Result> { let listener = self.inner.listen_on(addr).map_err(|err| err.map(DnsErr::Underlying))?; let listener = listener - .map::<_, fn(_) -> _>(|event| event.map(|upgr| { - upgr.map_err:: _, _>(DnsErr::Underlying) + .map_ok::<_, fn(_) -> _>(|event| event.map(|upgr| { + upgr.map_err::<_, fn(_) -> _>(DnsErr::Underlying) })) .map_err::<_, fn(_) -> _>(DnsErr::Underlying); Ok(listener) } fn dial(self, addr: Multiaddr) -> Result> { + // As an optimization, we immediately pass through if no component of the address contain + // a DNS protocol. let contains_dns = addr.iter().any(|cmp| match cmp { Protocol::Dns4(_) => true, Protocol::Dns6(_) => true, @@ -120,44 +126,61 @@ where if !contains_dns { trace!("Pass-through address without DNS: {}", addr); - let inner_dial = self.inner.dial(addr).map_err(|err| err.map(DnsErr::Underlying))?; - return Ok(Either::A(inner_dial.map_err(DnsErr::Underlying))); + let inner_dial = self.inner.dial(addr) + .map_err(|err| err.map(DnsErr::Underlying))?; + return Ok(inner_dial.map_err::<_, fn(_) -> _>(DnsErr::Underlying).left_future()); } - let resolver = self.resolver; - trace!("Dialing address with DNS: {}", addr); - let resolve_iters = addr.iter() - .map(move |cmp| match cmp { - Protocol::Dns4(ref name) => - Either::A(ResolveFuture { - name: if log_enabled!(Level::Trace) { - Some(name.clone().into_owned()) - } else { - None - }, - inner: resolver.resolve(name), - ty: ResolveTy::Dns4, - error_ty: PhantomData, - }), - Protocol::Dns6(ref name) => - Either::A(ResolveFuture { - name: if log_enabled!(Level::Trace) { - Some(name.clone().into_owned()) - } else { - None - }, - inner: resolver.resolve(name), - ty: ResolveTy::Dns6, - error_ty: PhantomData, - }), - cmp => Either::B(future::ok(cmp.acquire())) - }) - .collect::>() - .into_iter(); + let resolve_futs = addr.iter() + .map(|cmp| match cmp { + Protocol::Dns4(ref name) | Protocol::Dns6(ref name) => { + let name = name.to_string(); + let to_resolve = format!("{}:0", name); + let (tx, rx) = oneshot::channel(); + self.thread_pool.spawn_ok(async { + let to_resolve = to_resolve; + let _ = tx.send(match to_resolve[..].to_socket_addrs() { + Ok(list) => Ok(list.map(|s| s.ip()).collect::>()), + Err(e) => Err(e), + }); + }); - let new_addr = JoinFuture { addr, future: future::join_all(resolve_iters) }; - Ok(Either::B(DialFuture { trans: Some(self.inner), future: Either::A(new_addr) })) + async { + let list = rx.await + .map_err(|_| { + error!("DNS resolver crashed"); + DnsErr::ResolveFail(name.clone()) + })? 
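// Sketch of the resolution strategy above: `std::net::ToSocketAddrs` is blocking, so
// the lookup is pushed onto a separate thread and its result is sent back through a
// `futures::channel::oneshot` channel that the dial future can await. This sketch
// uses `std::thread::spawn` instead of the patch's `ThreadPool`; the function name
// `resolve_on_thread` is made up for illustration.
use futures::channel::oneshot;
use std::net::{SocketAddr, ToSocketAddrs};

fn resolve_on_thread(host: String) -> oneshot::Receiver<std::io::Result<Vec<SocketAddr>>> {
    let (tx, rx) = oneshot::channel();
    std::thread::spawn(move || {
        // `(host, 0)` resolves the host name with a dummy port, mirroring the
        // `format!("{}:0", name)` trick in the patch.
        let result = (host.as_str(), 0u16)
            .to_socket_addrs()
            .map(|addrs| addrs.collect());
        let _ = tx.send(result);
    });
    rx
}

fn main() {
    futures::executor::block_on(async {
        match resolve_on_thread("example.com".to_string()).await {
            Ok(Ok(addrs)) => println!("resolved to {:?}", addrs),
            Ok(Err(err)) => eprintln!("resolution failed: {}", err),
            Err(_canceled) => eprintln!("resolver thread dropped the sender"),
        }
    });
}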
+ .map_err(|err| DnsErr::ResolveError { + domain_name: name.clone(), + error: err, + })?; + + list.into_iter().next() + .map(|n| Protocol::from(n)) // TODO: doesn't take dns4/dns6 into account + .ok_or_else(|| DnsErr::ResolveFail(name)) + }.left_future() + }, + cmp => future::ready(Ok(cmp.acquire())).right_future() + }) + .collect::>(); + + let inner = self.inner; + Ok(future::Either::Right(Box::pin(async { + let addr = addr; + let outcome: Vec<_> = resolve_futs.collect().await; + let outcome = outcome.into_iter().collect::, _>>()?; + let outcome = outcome.into_iter().collect::(); + debug!("DNS resolution outcome: {} => {}", addr, outcome); + + match inner.dial(outcome) { + Ok(d) => d.await.map_err(DnsErr::Underlying), + Err(TransportError::MultiaddrNotSupported(_addr)) => + Err(DnsErr::MultiaddrNotSupported), + Err(TransportError::Other(err)) => Err(DnsErr::Underlying(err)), + } + }) as Pin>)) } } @@ -205,116 +228,17 @@ where TErr: error::Error + 'static } } -// How to resolve; to an IPv4 address or an IPv6 address? -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -enum ResolveTy { - Dns4, - Dns6, -} - -/// Future, performing DNS resolution. -#[derive(Debug)] -pub struct ResolveFuture { - name: Option, - inner: T, - ty: ResolveTy, - error_ty: PhantomData, -} - -impl Future for ResolveFuture -where - T: Future, Error = io::Error> -{ - type Item = Protocol<'static>; - type Error = DnsErr; - - fn poll(&mut self) -> Poll { - let ty = self.ty; - let addrs = try_ready!(self.inner.poll().map_err(|error| { - let domain_name = self.name.take().unwrap_or_default(); - DnsErr::ResolveError { domain_name, error } - })); - - trace!("DNS component resolution: {:?} => {:?}", self.name, addrs); - let mut addrs = addrs - .into_iter() - .filter_map(move |addr| match (addr, ty) { - (IpAddr::V4(addr), ResolveTy::Dns4) => Some(Protocol::Ip4(addr)), - (IpAddr::V6(addr), ResolveTy::Dns6) => Some(Protocol::Ip6(addr)), - _ => None, - }); - match addrs.next() { - Some(a) => Ok(Async::Ready(a)), - None => Err(DnsErr::ResolveFail(self.name.take().unwrap_or_default())) - } - } -} - -/// Build final multi-address from resolving futures. -#[derive(Debug)] -pub struct JoinFuture { - addr: Multiaddr, - future: T -} - -impl Future for JoinFuture -where - T: Future>> -{ - type Item = Multiaddr; - type Error = T::Error; - - fn poll(&mut self) -> Poll { - let outcome = try_ready!(self.future.poll()); - let outcome: Multiaddr = outcome.into_iter().collect(); - debug!("DNS resolution outcome: {} => {}", self.addr, outcome); - Ok(Async::Ready(outcome)) - } -} - -/// Future, dialing the resolved multi-address. 
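// Sketch of the boxed-dial pattern above: the hand-written `DialFuture`/`JoinFuture`
// state machines are replaced by an `async` block whose compiler-generated future is
// type-erased behind `Pin<Box<dyn Future ...>>`. `BoxedDial` and `dial_stub` are
// illustrative names, not APIs of this crate.
use std::future::Future;
use std::pin::Pin;

type BoxedDial<T, E> = Pin<Box<dyn Future<Output = Result<T, E>> + Send>>;

fn dial_stub(addr: String) -> BoxedDial<String, std::io::Error> {
    Box::pin(async move {
        // Any number of awaits (resolution, the inner transport's dial, ...) can
        // happen here; the compiler generates the state machine.
        Ok::<_, std::io::Error>(format!("connected to {}", addr))
    })
}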
-#[derive(Debug)] -pub struct DialFuture { - trans: Option, - future: Either, -} - -impl Future for DialFuture -where - T: Transport, - F: Future>, - TErr: error::Error, -{ - type Item = T::Output; - type Error = DnsErr; - - fn poll(&mut self) -> Poll { - loop { - let next = match self.future { - Either::A(ref mut f) => { - let addr = try_ready!(f.poll()); - match self.trans.take().unwrap().dial(addr) { - Ok(dial) => Either::B(dial), - Err(_) => return Err(DnsErr::MultiaddrNotSupported) - } - } - Either::B(ref mut f) => return f.poll().map_err(DnsErr::Underlying) - }; - self.future = next - } - } -} - #[cfg(test)] mod tests { - use libp2p_tcp::TcpConfig; - use futures::future; + use super::DnsConfig; + use futures::prelude::*; use libp2p_core::{ Transport, multiaddr::{Protocol, Multiaddr}, - transport::TransportError + transport::ListenerEvent, + transport::TransportError, }; - use super::DnsConfig; + use std::pin::Pin; #[test] fn basic_resolve() { @@ -322,11 +246,11 @@ mod tests { struct CustomTransport; impl Transport for CustomTransport { - type Output = ::Output; - type Error = ::Error; - type Listener = ::Listener; - type ListenerUpgrade = ::ListenerUpgrade; - type Dial = future::Empty; + type Output = (); + type Error = std::io::Error; + type Listener = Pin, Self::Error>>>>; + type ListenerUpgrade = Pin>>>; + type Dial = Pin>>>; fn listen_on(self, _: Multiaddr) -> Result> { unreachable!() @@ -340,22 +264,36 @@ mod tests { _ => panic!(), }; match addr[0] { - Protocol::Dns4(_) => (), - Protocol::Dns6(_) => (), + Protocol::Ip4(_) => (), + Protocol::Ip6(_) => (), _ => panic!(), }; - Ok(future::empty()) + Ok(Box::pin(future::ready(Ok(())))) } } - let transport = DnsConfig::new(CustomTransport); + futures::executor::block_on(async move { + let transport = DnsConfig::new(CustomTransport).unwrap(); - let _ = transport - .clone() - .dial("/dns4/example.com/tcp/20000".parse().unwrap()) - .unwrap(); - let _ = transport - .dial("/dns6/example.com/tcp/20000".parse().unwrap()) - .unwrap(); + let _ = transport + .clone() + .dial("/dns4/example.com/tcp/20000".parse().unwrap()) + .unwrap() + .await + .unwrap(); + + let _ = transport + .clone() + .dial("/dns6/example.com/tcp/20000".parse().unwrap()) + .unwrap() + .await + .unwrap(); + + let _ = transport + .dial("/ip4/1.2.3.4/tcp/20000".parse().unwrap()) + .unwrap() + .await + .unwrap(); + }); } } diff --git a/transports/tcp/Cargo.toml b/transports/tcp/Cargo.toml index 03d3a83d..2b28f8b5 100644 --- a/transports/tcp/Cargo.toml +++ b/transports/tcp/Cargo.toml @@ -10,15 +10,11 @@ keywords = ["peer-to-peer", "libp2p", "networking"] categories = ["network-programming", "asynchronous"] [dependencies] +async-std = "0.99" bytes = "0.4" get_if_addrs = "0.5.3" ipnet = "2.0.0" libp2p-core = { version = "0.12.0", path = "../../core" } log = "0.4.1" -futures = "0.1" -tokio-io = "0.1" -tokio-timer = "0.2" -tokio-tcp = "0.1" - -[dev-dependencies] -tokio = "0.1" +futures-preview = "0.3.0-alpha.17" +futures-timer = "0.3" diff --git a/transports/tcp/src/lib.rs b/transports/tcp/src/lib.rs index d42b4f44..adb54396 100644 --- a/transports/tcp/src/lib.rs +++ b/transports/tcp/src/lib.rs @@ -20,8 +20,6 @@ //! Implementation of the libp2p `Transport` trait for TCP/IP. //! -//! Uses [the *tokio* library](https://tokio.rs). -//! //! # Usage //! //! Example: @@ -38,11 +36,13 @@ //! The `TcpConfig` structs implements the `Transport` trait of the `swarm` library. See the //! documentation of `swarm` and of libp2p in general to learn how to use the `Transport` trait. 
+use async_std::net::TcpStream; use futures::{ - future::{self, Either, FutureResult}, + future::{self, Ready}, + io::Initializer, prelude::*, - stream::{self, Chain, IterOk, Once} }; +use futures_timer::Delay; use get_if_addrs::{IfAddr, get_if_addrs}; use ipnet::{IpNet, Ipv4Net, Ipv6Net}; use libp2p_core::{ @@ -53,15 +53,13 @@ use libp2p_core::{ use log::{debug, trace}; use std::{ collections::VecDeque, - io::{self, Read, Write}, + io, iter::{self, FromIterator}, net::{IpAddr, SocketAddr}, - time::{Duration, Instant}, - vec::IntoIter + pin::Pin, + task::{Context, Poll}, + time::Duration }; -use tokio_io::{AsyncRead, AsyncWrite}; -use tokio_timer::Delay; -use tokio_tcp::{ConnectFuture, Incoming, TcpStream}; /// Represents the configuration for a TCP/IP transport capability for libp2p. /// @@ -130,9 +128,9 @@ impl TcpConfig { impl Transport for TcpConfig { type Output = TcpTransStream; type Error = io::Error; - type Listener = TcpListener; - type ListenerUpgrade = FutureResult; - type Dial = TcpDialFut; + type Listener = Pin, io::Error>> + Send>>; + type ListenerUpgrade = Ready>; + type Dial = Pin> + Send>>; fn listen_on(self, addr: Multiaddr) -> Result> { let socket_addr = @@ -142,54 +140,59 @@ impl Transport for TcpConfig { return Err(TransportError::MultiaddrNotSupported(addr)) }; - let listener = tokio_tcp::TcpListener::bind(&socket_addr).map_err(TransportError::Other)?; - let local_addr = listener.local_addr().map_err(TransportError::Other)?; - let port = local_addr.port(); + async fn do_listen(cfg: TcpConfig, socket_addr: SocketAddr) + -> Result>>, io::Error>>, io::Error> + { + let listener = async_std::net::TcpListener::bind(&socket_addr).await?; + let local_addr = listener.local_addr()?; + let port = local_addr.port(); - // Determine all our listen addresses which is either a single local IP address - // or (if a wildcard IP address was used) the addresses of all our interfaces, - // as reported by `get_if_addrs`. - let addrs = - if socket_addr.ip().is_unspecified() { - let addrs = host_addresses(port).map_err(TransportError::Other)?; - debug!("Listening on {:?}", addrs.iter().map(|(_, _, ma)| ma).collect::>()); - Addresses::Many(addrs) - } else { - let ma = ip_to_multiaddr(local_addr.ip(), port); - debug!("Listening on {:?}", ma); - Addresses::One(ma) + // Determine all our listen addresses which is either a single local IP address + // or (if a wildcard IP address was used) the addresses of all our interfaces, + // as reported by `get_if_addrs`. + let addrs = + if socket_addr.ip().is_unspecified() { + let addrs = host_addresses(port)?; + debug!("Listening on {:?}", addrs.iter().map(|(_, _, ma)| ma).collect::>()); + Addresses::Many(addrs) + } else { + let ma = ip_to_multiaddr(local_addr.ip(), port); + debug!("Listening on {:?}", ma); + Addresses::One(ma) + }; + + // Generate `NewAddress` events for each new `Multiaddr`. + let pending = match addrs { + Addresses::One(ref ma) => { + let event = ListenerEvent::NewAddress(ma.clone()); + let mut list = VecDeque::new(); + list.push_back(Ok(event)); + list + } + Addresses::Many(ref aa) => { + aa.iter() + .map(|(_, _, ma)| ma) + .cloned() + .map(ListenerEvent::NewAddress) + .map(Result::Ok) + .collect::>() + } }; - // Generate `NewAddress` events for each new `Multiaddr`. 
- let events = match addrs { - Addresses::One(ref ma) => { - let event = ListenerEvent::NewAddress(ma.clone()); - Either::A(stream::once(Ok(event))) - } - Addresses::Many(ref aa) => { - let events = aa.iter() - .map(|(_, _, ma)| ma) - .cloned() - .map(ListenerEvent::NewAddress) - .collect::>(); - Either::B(stream::iter_ok(events)) - } - }; + let listen_stream = TcpListenStream { + stream: listener, + pause: None, + pause_duration: cfg.sleep_on_error, + port, + addrs, + pending, + config: cfg + }; - let stream = TcpListenStream { - inner: Listener::new(listener.incoming(), self.sleep_on_error), - port, - addrs, - pending: VecDeque::new(), - config: self - }; + Ok(stream::unfold(listen_stream, |s| s.next().map(Some))) + } - Ok(TcpListener { - inner: match events { - Either::A(e) => Either::A(e.chain(stream)), - Either::B(e) => Either::B(e.chain(stream)) - } - }) + Ok(Box::pin(do_listen(self, socket_addr).try_flatten_stream())) } fn dial(self, addr: Multiaddr) -> Result> { @@ -206,12 +209,13 @@ impl Transport for TcpConfig { debug!("Dialing {}", addr); - let future = TcpDialFut { - inner: TcpStream::connect(&socket_addr), - config: self - }; + async fn do_dial(cfg: TcpConfig, socket_addr: SocketAddr) -> Result { + let stream = TcpStream::connect(&socket_addr).await?; + apply_config(&cfg, &stream)?; + Ok(TcpTransStream { inner: stream }) + } - Ok(future) + Ok(Box::pin(do_dial(self, socket_addr))) } } @@ -270,11 +274,11 @@ fn host_addresses(port: u16) -> io::Result> { /// Applies the socket configuration parameters to a socket. fn apply_config(config: &TcpConfig, socket: &TcpStream) -> Result<(), io::Error> { if let Some(recv_buffer_size) = config.recv_buffer_size { - socket.set_recv_buffer_size(recv_buffer_size)?; + // TODO: socket.set_recv_buffer_size(recv_buffer_size)?; } if let Some(send_buffer_size) = config.send_buffer_size { - socket.set_send_buffer_size(send_buffer_size)?; + // TODO: socket.set_send_buffer_size(send_buffer_size)?; } if let Some(ttl) = config.ttl { @@ -282,7 +286,7 @@ fn apply_config(config: &TcpConfig, socket: &TcpStream) -> Result<(), io::Error> } if let Some(keepalive) = config.keepalive { - socket.set_keepalive(keepalive)?; + // TODO: socket.set_keepalive(keepalive)?; } if let Some(nodelay) = config.nodelay { @@ -292,55 +296,6 @@ fn apply_config(config: &TcpConfig, socket: &TcpStream) -> Result<(), io::Error> Ok(()) } -/// Future that dials a TCP/IP address. -#[derive(Debug)] -#[must_use = "futures do nothing unless polled"] -pub struct TcpDialFut { - inner: ConnectFuture, - /// Original configuration. - config: TcpConfig, -} - -impl Future for TcpDialFut { - type Item = TcpTransStream; - type Error = io::Error; - - fn poll(&mut self) -> Poll { - match self.inner.poll() { - Ok(Async::Ready(stream)) => { - apply_config(&self.config, &stream)?; - Ok(Async::Ready(TcpTransStream { inner: stream })) - } - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(err) => { - debug!("Error while dialing => {:?}", err); - Err(err) - } - } - } -} - -/// Stream of `ListenerEvent`s. -#[derive(Debug)] -pub struct TcpListener { - inner: Either< - Chain>, io::Error>, TcpListenStream>, - Chain>>, io::Error>, TcpListenStream> - > -} - -impl Stream for TcpListener { - type Item = ListenerEvent>; - type Error = io::Error; - - fn poll(&mut self) -> Poll, Self::Error> { - match self.inner { - Either::A(ref mut it) => it.poll(), - Either::B(ref mut it) => it.poll() - } - } -} - /// Listen address information. 
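// Sketch of the `stream::unfold` trick used above for the TCP listener: an owned
// state with an `async fn next(self) -> (Item, Self)` can be turned into a `Stream`
// without writing `poll_next` by hand. `Counter` is an illustrative type.
use futures::prelude::*;

struct Counter(u32);

impl Counter {
    async fn next(mut self) -> (u32, Self) {
        self.0 += 1;
        (self.0, self)
    }
}

fn counter_stream() -> impl Stream<Item = u32> {
    stream::unfold(Counter(0), |c| c.next().map(Some))
}

fn main() {
    let first_three: Vec<u32> =
        futures::executor::block_on(counter_stream().take(3).collect());
    assert_eq!(first_three, vec![1, 2, 3]);
}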
#[derive(Debug)] enum Addresses { @@ -350,61 +305,16 @@ enum Addresses { Many(Vec<(IpAddr, IpNet, Multiaddr)>) } -type Buffer = VecDeque>>; +type Buffer = VecDeque>>, io::Error>>; -/// Incoming connection stream which pauses after errors. -#[derive(Debug)] -struct Listener { +/// Stream that listens on an TCP/IP address. +pub struct TcpListenStream { /// The incoming connections. - stream: S, + stream: async_std::net::TcpListener, /// The current pause if any. pause: Option, /// How long to pause after an error. - pause_duration: Duration -} - -impl Listener -where - S: Stream, - S::Error: std::fmt::Display -{ - fn new(stream: S, duration: Duration) -> Self { - Listener { stream, pause: None, pause_duration: duration } - } -} - -impl Stream for Listener -where - S: Stream, - S::Error: std::fmt::Display -{ - type Item = S::Item; - type Error = S::Error; - - /// Polls for incoming connections, pausing if an error is encountered. - fn poll(&mut self) -> Poll, S::Error> { - match self.pause.as_mut().map(|p| p.poll()) { - Some(Ok(Async::NotReady)) => return Ok(Async::NotReady), - Some(Ok(Async::Ready(()))) | Some(Err(_)) => { self.pause.take(); } - None => () - } - - match self.stream.poll() { - Ok(x) => Ok(x), - Err(e) => { - debug!("error accepting incoming connection: {}", e); - self.pause = Some(Delay::new(Instant::now() + self.pause_duration)); - Err(e) - } - } - } -} - -/// Stream that listens on an TCP/IP address. -#[derive(Debug)] -pub struct TcpListenStream { - /// Stream of incoming sockets. - inner: Listener, + pause_duration: Duration, /// The port which we use as our listen port in listener event addresses. port: u16, /// The set of known addresses. @@ -445,7 +355,7 @@ fn check_for_interface_changes( for (ip, _, ma) in old_listen_addrs.iter() { if listen_addrs.iter().find(|(i, ..)| i == ip).is_none() { debug!("Expired listen address: {}", ma); - pending.push_back(ListenerEvent::AddressExpired(ma.clone())); + pending.push_back(Ok(ListenerEvent::AddressExpired(ma.clone()))); } } @@ -453,7 +363,7 @@ fn check_for_interface_changes( for (ip, _, ma) in listen_addrs.iter() { if old_listen_addrs.iter().find(|(i, ..)| i == ip).is_none() { debug!("New listen address: {}", ma); - pending.push_back(ListenerEvent::NewAddress(ma.clone())); + pending.push_back(Ok(ListenerEvent::NewAddress(ma.clone()))); } } @@ -470,21 +380,26 @@ fn check_for_interface_changes( Ok(()) } -impl Stream for TcpListenStream { - type Item = ListenerEvent>; - type Error = io::Error; - - fn poll(&mut self) -> Poll, io::Error> { +impl TcpListenStream { + /// Takes ownership of the listener, and returns the next incoming event and the listener. + async fn next(mut self) -> (Result>>, io::Error>, Self) { loop { if let Some(event) = self.pending.pop_front() { - return Ok(Async::Ready(Some(event))) + return (event, self); } - let sock = match self.inner.poll() { - Ok(Async::Ready(Some(sock))) => sock, - Ok(Async::Ready(None)) => return Ok(Async::Ready(None)), - Ok(Async::NotReady) => return Ok(Async::NotReady), - Err(e) => return Err(e) + if let Some(pause) = self.pause.take() { + let _ = pause.await; + } + + // TODO: do we get the peer_addr at the same time? 
+ let (sock, _) = match self.stream.accept().await { + Ok(s) => s, + Err(e) => { + debug!("error accepting incoming connection: {}", e); + self.pause = Some(Delay::new(self.pause_duration)); + return (Err(e), self); + } }; let sock_addr = match sock.peer_addr() { @@ -498,7 +413,9 @@ impl Stream for TcpListenStream { let local_addr = match sock.local_addr() { Ok(sock_addr) => { if let Addresses::Many(ref mut addrs) = self.addrs { - check_for_interface_changes(&sock_addr, self.port, addrs, &mut self.pending)? + if let Err(err) = check_for_interface_changes(&sock_addr, self.port, addrs, &mut self.pending) { + return (Err(err), self); + } } ip_to_multiaddr(sock_addr.ip(), sock_addr.port()) } @@ -513,19 +430,19 @@ impl Stream for TcpListenStream { match apply_config(&self.config, &sock) { Ok(()) => { trace!("Incoming connection from {} at {}", remote_addr, local_addr); - self.pending.push_back(ListenerEvent::Upgrade { + self.pending.push_back(Ok(ListenerEvent::Upgrade { upgrade: future::ok(TcpTransStream { inner: sock }), local_addr, remote_addr - }) + })) } Err(err) => { debug!("Error upgrading incoming connection from {}: {:?}", remote_addr, err); - self.pending.push_back(ListenerEvent::Upgrade { + self.pending.push_back(Ok(ListenerEvent::Upgrade { upgrade: future::err(err), local_addr, remote_addr - }) + })) } } } @@ -538,35 +455,27 @@ pub struct TcpTransStream { inner: TcpStream, } -impl Read for TcpTransStream { - fn read(&mut self, buf: &mut [u8]) -> Result { - self.inner.read(buf) - } -} - impl AsyncRead for TcpTransStream { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { - self.inner.prepare_uninitialized_buffer(buf) + fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { + AsyncRead::poll_read(Pin::new(&mut self.inner), cx, buf) } - fn read_buf(&mut self, buf: &mut B) -> Poll { - self.inner.read_buf(buf) - } -} - -impl Write for TcpTransStream { - fn write(&mut self, buf: &[u8]) -> Result { - self.inner.write(buf) - } - - fn flush(&mut self) -> Result<(), io::Error> { - self.inner.flush() + unsafe fn initializer(&self) -> Initializer { + self.inner.initializer() } } impl AsyncWrite for TcpTransStream { - fn shutdown(&mut self) -> Poll<(), io::Error> { - AsyncWrite::shutdown(&mut self.inner) + fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + AsyncWrite::poll_write(Pin::new(&mut self.inner), cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + AsyncWrite::poll_flush(Pin::new(&mut self.inner), cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + AsyncWrite::poll_close(Pin::new(&mut self.inner), cx) } } @@ -615,8 +524,7 @@ mod tests { .expect("listener"); // Get the first address. - let addr = listener.by_ref() - .wait() + let addr = futures::executor::block_on_stream(listener.by_ref()) .next() .expect("some event") .expect("no error") @@ -626,7 +534,7 @@ mod tests { // Process all initial `NewAddress` events and make sure they // do not contain wildcard address or port. 
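// Sketch of the delegation pattern used for `TcpTransStream` above: the futures 0.3
// `AsyncRead`/`AsyncWrite` traits take `Pin<&mut Self>`, so a newtype around an
// `Unpin` socket simply forwards each call via `Pin::new(&mut self.inner)`. Written
// against the stabilized futures 0.3 traits, so the alpha-only `initializer` method
// is omitted; `PassThrough` is an illustrative wrapper.
use futures::io::{AsyncRead, AsyncWrite};
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};

struct PassThrough<T> {
    inner: T,
}

impl<T: AsyncRead + Unpin> AsyncRead for PassThrough<T> {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.inner).poll_read(cx, buf)
    }
}

impl<T: AsyncWrite + Unpin> AsyncWrite for PassThrough<T> {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.inner).poll_write(cx, buf)
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_flush(cx)
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_close(cx)
    }
}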
let server = listener - .take_while(|event| match event { + .take_while(|event| match event.as_ref().unwrap() { ListenerEvent::NewAddress(a) => { let mut iter = a.iter(); match iter.next().expect("ip address") { @@ -639,14 +547,14 @@ mod tests { } else { panic!("No TCP port in address: {}", a) } - Ok(true) + futures::future::ready(true) } - _ => Ok(false) + _ => futures::future::ready(false) }) - .for_each(|_| Ok(())); + .for_each(|_| futures::future::ready(())); let client = TcpConfig::new().dial(addr).expect("dialer"); - tokio::run(server.join(client).map(|_| ()).map_err(|e| panic!("error: {}", e))) + futures::executor::block_on(futures::future::join(server, client)).1.unwrap(); } #[test] @@ -705,8 +613,6 @@ mod tests { std::thread::spawn(move || { let addr = "/ip4/127.0.0.1/tcp/12345".parse::().unwrap(); let tcp = TcpConfig::new(); - let mut rt = Runtime::new().unwrap(); - let handle = rt.handle(); let listener = tcp.listen_on(addr).unwrap() .filter_map(ListenerEvent::into_upgrade) .for_each(|(sock, _)| { @@ -720,12 +626,11 @@ mod tests { // Spawn the future as a concurrent task handle.spawn(handle_conn).unwrap(); - Ok(()) + futures::future::ready(()) }) }); - rt.block_on(listener).unwrap(); - rt.run().unwrap(); + futures::executor::block_on(listener); }); std::thread::sleep(std::time::Duration::from_millis(100)); let addr = "/ip4/127.0.0.1/tcp/12345".parse::().unwrap(); @@ -733,13 +638,12 @@ mod tests { // Obtain a future socket through dialing let socket = tcp.dial(addr.clone()).unwrap(); // Define what to do with the socket once it's obtained - let action = socket.then(|sock| -> Result<(), ()> { + let action = socket.then(|sock| { sock.unwrap().write(&[0x1, 0x2, 0x3]).unwrap(); - Ok(()) + futures::future::ready(()) }); // Execute the future in our event loop - let mut rt = Runtime::new().unwrap(); - let _ = rt.block_on(action).unwrap(); + futures::executor::block_on(action); } #[test] @@ -749,7 +653,7 @@ mod tests { let addr = "/ip4/127.0.0.1/tcp/0".parse::().unwrap(); assert!(addr.to_string().contains("tcp/0")); - let new_addr = tcp.listen_on(addr).unwrap().wait() + let new_addr = futures::executor::block_on_stream(tcp.listen_on(addr).unwrap()) .next() .expect("some event") .expect("no error") @@ -766,7 +670,7 @@ mod tests { let addr: Multiaddr = "/ip6/::1/tcp/0".parse().unwrap(); assert!(addr.to_string().contains("tcp/0")); - let new_addr = tcp.listen_on(addr).unwrap().wait() + let new_addr = futures::executor::block_on_stream(tcp.listen_on(addr).unwrap()) .next() .expect("some event") .expect("no error") diff --git a/transports/uds/Cargo.toml b/transports/uds/Cargo.toml index 7ac9b3cd..2293486a 100644 --- a/transports/uds/Cargo.toml +++ b/transports/uds/Cargo.toml @@ -12,10 +12,11 @@ categories = ["network-programming", "asynchronous"] [target.'cfg(all(unix, not(any(target_os = "emscripten", target_os = "unknown"))))'.dependencies] libp2p-core = { version = "0.12.0", path = "../../core" } log = "0.4.1" -futures = "0.1" -tokio-uds = "0.2" +futures-preview = "0.3.0-alpha.17" +romio = "0.3.0-alpha.9" [target.'cfg(all(unix, not(any(target_os = "emscripten", target_os = "unknown"))))'.dev-dependencies] tempfile = "3.0" -tokio = "0.1" -tokio-io = "0.1" + +[dev-dependencies] +async-std = "0.99" diff --git a/transports/uds/src/lib.rs b/transports/uds/src/lib.rs index 4be826ca..76f10dec 100644 --- a/transports/uds/src/lib.rs +++ b/transports/uds/src/lib.rs @@ -20,8 +20,6 @@ //! Implementation of the libp2p `Transport` trait for Unix domain sockets. //! -//! 
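// Sketch of the executor change in the tests above: instead of building a tokio
// `Runtime`, futures are driven to completion with `futures::executor::block_on`,
// and streams can be consumed synchronously with `block_on_stream`.
fn main() {
    let sum = futures::executor::block_on(async { 2 + 2 });
    assert_eq!(sum, 4);

    let mut events =
        futures::executor::block_on_stream(futures::stream::iter(vec![1, 2, 3]));
    assert_eq!(events.next(), Some(1));
    assert_eq!(events.collect::<Vec<_>>(), vec![2, 3]);
}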
Uses [the *tokio* library](https://tokio.rs). -//! //! # Platform support //! //! This transport only works on Unix platforms. @@ -46,27 +44,27 @@ #![cfg(all(unix, not(any(target_os = "emscripten", target_os = "unknown"))))] -use futures::{future::{self, FutureResult}, prelude::*, try_ready}; +use futures::{prelude::*, ready, future::Ready}; use futures::stream::Stream; use log::debug; -use std::{io, path::PathBuf}; +use romio::uds::{UnixListener, UnixStream}; +use std::{io, path::PathBuf, pin::Pin, task::Context, task::Poll}; use libp2p_core::{ Transport, multiaddr::{Protocol, Multiaddr}, transport::{ListenerEvent, TransportError} }; -use tokio_uds::{UnixListener, UnixStream}; /// Represents the configuration for a Unix domain sockets transport capability for libp2p. /// -/// The Unixs sockets created by libp2p will need to be progressed by running the futures and +/// The Unix sockets created by libp2p will need to be progressed by running the futures and /// streams obtained by libp2p through the tokio reactor. #[derive(Debug, Clone)] pub struct UdsConfig { } impl UdsConfig { - /// Creates a new configuration object for TCP/IP. + /// Creates a new configuration object for Unix domain sockets. #[inline] pub fn new() -> UdsConfig { UdsConfig {} @@ -76,9 +74,9 @@ impl UdsConfig { impl Transport for UdsConfig { type Output = UnixStream; type Error = io::Error; - type Listener = ListenerStream; - type ListenerUpgrade = FutureResult; - type Dial = tokio_uds::ConnectFuture; + type Listener = ListenerStream; + type ListenerUpgrade = Ready>; + type Dial = romio::uds::ConnectFuture; fn listen_on(self, addr: Multiaddr) -> Result> { if let Ok(path) = multiaddr_to_path(&addr) { @@ -145,43 +143,40 @@ pub struct ListenerStream { impl Stream for ListenerStream where - T: Stream + T: TryStream + Unpin { - type Item = ListenerEvent>; - type Error = T::Error; + type Item = Result>>, T::Error>; - fn poll(&mut self) -> Poll, Self::Error> { + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { if self.tell_new_addr { self.tell_new_addr = false; - return Ok(Async::Ready(Some(ListenerEvent::NewAddress(self.addr.clone())))) + return Poll::Ready(Some(Ok(ListenerEvent::NewAddress(self.addr.clone())))) } - match try_ready!(self.stream.poll()) { + + match ready!(TryStream::try_poll_next(Pin::new(&mut self.stream), cx)) { Some(item) => { debug!("incoming connection on {}", self.addr); - Ok(Async::Ready(Some(ListenerEvent::Upgrade { - upgrade: future::ok(item), + Poll::Ready(Some(Ok(ListenerEvent::Upgrade { + upgrade: future::ready(item), local_addr: self.addr.clone(), remote_addr: self.addr.clone() }))) } - None => Ok(Async::Ready(None)) + None => Poll::Ready(None) } } } #[cfg(test)] mod tests { - use tokio::runtime::current_thread::Runtime; use super::{multiaddr_to_path, UdsConfig}; use futures::prelude::*; use std::{self, borrow::Cow, path::Path}; use libp2p_core::{ Transport, - multiaddr::{Protocol, Multiaddr}, - transport::ListenerEvent + multiaddr::{Protocol, Multiaddr} }; use tempfile; - use tokio_io; #[test] fn multiaddr_to_path_conversion() { @@ -202,64 +197,46 @@ mod tests { #[test] fn communicating_between_dialer_and_listener() { - use std::io::Write; let temp_dir = tempfile::tempdir().unwrap(); let socket = temp_dir.path().join("socket"); let addr = Multiaddr::from(Protocol::Unix(Cow::Owned(socket.to_string_lossy().into_owned()))); let addr2 = addr.clone(); - std::thread::spawn(move || { - let tcp = UdsConfig::new(); - - let mut rt = Runtime::new().unwrap(); - let handle = rt.handle(); - 
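// Sketch of a hand-written `Stream::poll_next`, as done for the Unix-socket
// `ListenerStream` above: `futures::ready!` replaces the old `try_ready!`, returning
// early on `Poll::Pending` and leaving the `Result` to be handled explicitly.
// `Doubled` is an illustrative wrapper around an `Unpin` inner stream.
use futures::{ready, Stream};
use std::pin::Pin;
use std::task::{Context, Poll};

struct Doubled<S> {
    inner: S,
}

impl<S> Stream for Doubled<S>
where
    S: Stream<Item = Result<u32, std::io::Error>> + Unpin,
{
    type Item = Result<u32, std::io::Error>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
            Some(Ok(n)) => Poll::Ready(Some(Ok(n * 2))),
            Some(Err(e)) => Poll::Ready(Some(Err(e))),
            None => Poll::Ready(None),
        }
    }
}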
let listener = tcp.listen_on(addr2).unwrap() - .filter_map(ListenerEvent::into_upgrade) - .for_each(|(sock, _)| { - sock.and_then(|sock| { - // Define what to do with the socket that just connected to us - // Which in this case is read 3 bytes - let handle_conn = tokio_io::io::read_exact(sock, [0; 3]) - .map(|(_, buf)| assert_eq!(buf, [1, 2, 3])) - .map_err(|err| panic!("IO error {:?}", err)); - - // Spawn the future as a concurrent task - handle.spawn(handle_conn).unwrap(); + async_std::task::spawn( + UdsConfig::new().listen_on(addr2).unwrap() + .try_filter_map(|ev| future::ok(ev.into_upgrade())) + .try_for_each(|(sock, _)| { + async { + let mut sock = sock.await.unwrap(); + let mut buf = [0u8; 3]; + sock.read_exact(&mut buf).await.unwrap(); + assert_eq!(buf, [1, 2, 3]); Ok(()) - }) - }); + } + }) + ); - rt.block_on(listener).unwrap(); - rt.run().unwrap(); + futures::executor::block_on(async { + let uds = UdsConfig::new(); + let mut socket = uds.dial(addr.clone()).unwrap().await.unwrap(); + socket.write(&[0x1, 0x2, 0x3]).await.unwrap(); }); - std::thread::sleep(std::time::Duration::from_millis(100)); - let tcp = UdsConfig::new(); - // Obtain a future socket through dialing - let socket = tcp.dial(addr.clone()).unwrap(); - // Define what to do with the socket once it's obtained - let action = socket.then(|sock| -> Result<(), ()> { - sock.unwrap().write(&[0x1, 0x2, 0x3]).unwrap(); - Ok(()) - }); - // Execute the future in our event loop - let mut rt = Runtime::new().unwrap(); - let _ = rt.block_on(action).unwrap(); } #[test] #[ignore] // TODO: for the moment unix addresses fail to parse fn larger_addr_denied() { - let tcp = UdsConfig::new(); + let uds = UdsConfig::new(); - let addr = "/ip4/127.0.0.1/tcp/12345/unix//foo/bar" + let addr = "/unix//foo/bar" .parse::() .unwrap(); - assert!(tcp.listen_on(addr).is_err()); + assert!(uds.listen_on(addr).is_err()); } #[test] #[ignore] // TODO: for the moment unix addresses fail to parse fn relative_addr_denied() { - assert!("/ip4/127.0.0.1/tcp/12345/unix/./foo/bar".parse::().is_err()); + assert!("/unix/./foo/bar".parse::().is_err()); } } diff --git a/transports/wasm-ext/Cargo.toml b/transports/wasm-ext/Cargo.toml index c7765666..6f649b9b 100644 --- a/transports/wasm-ext/Cargo.toml +++ b/transports/wasm-ext/Cargo.toml @@ -10,10 +10,9 @@ keywords = ["peer-to-peer", "libp2p", "networking"] categories = ["network-programming", "asynchronous"] [dependencies] -futures = "0.1" +futures-preview = "0.3.0-alpha.17" js-sys = "0.3.19" libp2p-core = { version = "0.12.0", path = "../../core" } parity-send-wrapper = "0.1.0" -tokio-io = "0.1" wasm-bindgen = "0.2.42" -wasm-bindgen-futures = "0.3.19" +wasm-bindgen-futures = { version = "0.3.25", features = ["futures_0_3"] } diff --git a/transports/wasm-ext/src/lib.rs b/transports/wasm-ext/src/lib.rs index a577294b..ffed6e59 100644 --- a/transports/wasm-ext/src/lib.rs +++ b/transports/wasm-ext/src/lib.rs @@ -32,11 +32,12 @@ //! module. //! -use futures::{future::FutureResult, prelude::*, stream::Stream, try_ready}; +use futures::{prelude::*, future::Ready, io::Initializer}; use libp2p_core::{transport::ListenerEvent, transport::TransportError, Multiaddr, Transport}; use parity_send_wrapper::SendWrapper; -use std::{collections::VecDeque, error, fmt, io, mem}; +use std::{collections::VecDeque, error, fmt, io, mem, pin::Pin, task::Context, task::Poll}; use wasm_bindgen::{JsCast, prelude::*}; +use wasm_bindgen_futures::futures_0_3::JsFuture; /// Contains the definition that one must match on the JavaScript side. 
pub mod ffi { @@ -156,7 +157,7 @@ impl Transport for ExtTransport { type Output = Connection; type Error = JsErr; type Listener = Listen; - type ListenerUpgrade = FutureResult; + type ListenerUpgrade = Ready>; type Dial = Dial; fn listen_on(self, addr: Multiaddr) -> Result> { @@ -200,7 +201,7 @@ impl Transport for ExtTransport { #[must_use = "futures do nothing unless polled"] pub struct Dial { /// A promise that will resolve to a `ffi::Connection` on success. - inner: SendWrapper, + inner: SendWrapper, } impl fmt::Debug for Dial { @@ -210,14 +211,13 @@ impl fmt::Debug for Dial { } impl Future for Dial { - type Item = Connection; - type Error = JsErr; + type Output = Result; - fn poll(&mut self) -> Poll { - match self.inner.poll() { - Ok(Async::Ready(connec)) => Ok(Async::Ready(Connection::new(connec.into()))), - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(err) => Err(JsErr::from(err)), + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + match Future::poll(Pin::new(&mut *self.inner), cx) { + Poll::Ready(Ok(connec)) => Poll::Ready(Ok(Connection::new(connec.into()))), + Poll::Pending => Poll::Pending, + Poll::Ready(Err(err)) => Poll::Ready(Err(JsErr::from(err))), } } } @@ -228,9 +228,9 @@ pub struct Listen { /// Iterator of `ListenEvent`s. iterator: SendWrapper, /// Promise that will yield the next `ListenEvent`. - next_event: Option>, + next_event: Option>, /// List of events that we are waiting to propagate. - pending_events: VecDeque>>, + pending_events: VecDeque>>>, } impl fmt::Debug for Listen { @@ -240,13 +240,12 @@ impl fmt::Debug for Listen { } impl Stream for Listen { - type Item = ListenerEvent>; - type Error = JsErr; + type Item = Result>>, JsErr>; - fn poll(&mut self) -> Poll, Self::Error> { + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { loop { if let Some(ev) = self.pending_events.pop_front() { - return Ok(Async::Ready(Some(ev))); + return Poll::Ready(Some(Ok(ev))); } if self.next_event.is_none() { @@ -258,11 +257,15 @@ impl Stream for Listen { } let event = if let Some(next_event) = self.next_event.as_mut() { - let e = ffi::ListenEvent::from(try_ready!(next_event.poll())); + let e = match Future::poll(Pin::new(&mut **next_event), cx) { + Poll::Ready(Ok(ev)) => ffi::ListenEvent::from(ev), + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err.into()))), + }; self.next_event = None; e } else { - return Ok(Async::Ready(None)); + return Poll::Ready(None); }; for addr in event @@ -319,7 +322,7 @@ pub struct Connection { /// When we write data using the FFI, a promise is returned containing the moment when the /// underlying transport is ready to accept data again. This promise is stored here. /// If this is `Some`, we must wait until the contained promise is resolved to write again. - previous_write_promise: Option>, + previous_write_promise: Option>, } impl Connection { @@ -341,7 +344,7 @@ enum ConnectionReadState { /// Some data have been read and are waiting to be transferred. Can be empty. PendingData(Vec), /// Waiting for a `Promise` containing the next data. - Waiting(SendWrapper), + Waiting(SendWrapper), /// An error occurred or an earlier read yielded EOF. 
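// Sketch of the wrapper-future pattern used for `Dial` above: the inner (Unpin)
// future is polled through `Pin::new`, and its output is converted into the public
// type before being returned. `MapToString` is an illustrative stand-in for the
// `JsFuture`-to-`Connection` conversion.
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

struct MapToString<F> {
    inner: F,
}

impl<F, T> Future for MapToString<F>
where
    F: Future<Output = Result<T, std::io::Error>> + Unpin,
    T: std::fmt::Display,
{
    type Output = Result<String, std::io::Error>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match Pin::new(&mut self.inner).poll(cx) {
            Poll::Ready(Ok(value)) => Poll::Ready(Ok(value.to_string())),
            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
            Poll::Pending => Poll::Pending,
        }
    }
}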
Finished, } @@ -352,11 +355,15 @@ impl fmt::Debug for Connection { } } -impl io::Read for Connection { - fn read(&mut self, buf: &mut [u8]) -> Result { +impl AsyncRead for Connection { + unsafe fn initializer(&self) -> Initializer { + Initializer::nop() + } + + fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { loop { match mem::replace(&mut self.read_state, ConnectionReadState::Finished) { - ConnectionReadState::Finished => break Err(io::ErrorKind::BrokenPipe.into()), + ConnectionReadState::Finished => break Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())), ConnectionReadState::PendingData(ref data) if data.is_empty() => { let iter_next = self.read_iterator.next().map_err(JsErr::from)?; @@ -376,22 +383,23 @@ impl io::Read for Connection { buf.copy_from_slice(&data[..buf.len()]); self.read_state = ConnectionReadState::PendingData(data.split_off(buf.len())); - break Ok(buf.len()); + break Poll::Ready(Ok(buf.len())); } else { let len = data.len(); buf[..len].copy_from_slice(&data); self.read_state = ConnectionReadState::PendingData(Vec::new()); - break Ok(len); + break Poll::Ready(Ok(len)); } } ConnectionReadState::Waiting(mut promise) => { - let data = match promise.poll().map_err(JsErr::from)? { - Async::Ready(ref data) if data.is_null() => break Ok(0), - Async::Ready(data) => data, - Async::NotReady => { + let data = match Future::poll(Pin::new(&mut *promise), cx) { + Poll::Ready(Ok(ref data)) if data.is_null() => break Poll::Ready(Ok(0)), + Poll::Ready(Ok(data)) => data, + Poll::Ready(Err(err)) => break Poll::Ready(Err(io::Error::from(JsErr::from(err)))), + Poll::Pending => { self.read_state = ConnectionReadState::Waiting(promise); - break Err(io::ErrorKind::WouldBlock.into()); + break Poll::Ready(Err(io::ErrorKind::WouldBlock.into())); } }; @@ -402,7 +410,7 @@ impl io::Read for Connection { if data_len <= buf.len() { data.copy_to(&mut buf[..data_len]); self.read_state = ConnectionReadState::PendingData(Vec::new()); - break Ok(data_len); + break Poll::Ready(Ok(data_len)); } else { let mut tmp_buf = vec![0; data_len]; data.copy_to(&mut tmp_buf[..]); @@ -415,23 +423,18 @@ impl io::Read for Connection { } } -impl tokio_io::AsyncRead for Connection { - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { - false - } -} - -impl io::Write for Connection { - fn write(&mut self, buf: &[u8]) -> Result { +impl AsyncWrite for Connection { + fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { // Note: as explained in the doc-comments of `Connection`, each call to this function must // map to exactly one call to `self.inner.write()`. if let Some(mut promise) = self.previous_write_promise.take() { - match promise.poll().map_err(JsErr::from)? { - Async::Ready(_) => (), - Async::NotReady => { + match Future::poll(Pin::new(&mut *promise), cx) { + Poll::Ready(Ok(_)) => (), + Poll::Ready(Err(err)) => return Poll::Ready(Err(io::Error::from(JsErr::from(err)))), + Poll::Pending => { self.previous_write_promise = Some(promise); - return Err(io::ErrorKind::WouldBlock.into()); + return Poll::Pending; } } } @@ -440,20 +443,20 @@ impl io::Write for Connection { self.previous_write_promise = Some(SendWrapper::new( self.inner.write(buf).map_err(JsErr::from)?.into(), )); - Ok(buf.len()) + Poll::Ready(Ok(buf.len())) } - fn flush(&mut self) -> Result<(), io::Error> { + fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll> { // There's no flushing mechanism. In the FFI we consider that writing implicitly flushes. 
- Ok(()) + Poll::Ready(Ok(())) } -} -impl tokio_io::AsyncWrite for Connection { - fn shutdown(&mut self) -> Poll<(), io::Error> { + fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll> { // Shutting down is considered instantaneous. - self.inner.shutdown().map_err(JsErr::from)?; - Ok(Async::Ready(())) + match self.inner.shutdown() { + Ok(()) => Poll::Ready(Ok(())), + Err(err) => Poll::Ready(Err(io::Error::from(JsErr::from(err)))), + } } } diff --git a/transports/websocket/Cargo.toml b/transports/websocket/Cargo.toml index 026fbd56..1d042920 100644 --- a/transports/websocket/Cargo.toml +++ b/transports/websocket/Cargo.toml @@ -11,11 +11,11 @@ categories = ["network-programming", "asynchronous"] [dependencies] bytes = "0.4.6" -futures = "0.1" +futures-preview = { version = "0.3.0-alpha.17", features = ["compat"] } +futures_codec = "0.2.0" libp2p-core = { version = "0.12.0", path = "../../core" } log = "0.4.1" rw-stream-sink = { version = "0.1.1", path = "../../misc/rw-stream-sink" } -tokio-codec = "0.1.1" tokio-io = "0.1.12" tokio-rustls = "0.10.0-alpha.3" soketto = { version = "0.2.3", features = ["deflate"] } @@ -24,4 +24,3 @@ webpki-roots = "0.16.0" [dev-dependencies] libp2p-tcp = { version = "0.12.0", path = "../tcp" } -tokio = "0.1.20" diff --git a/transports/websocket/src/framed.rs b/transports/websocket/src/framed.rs index b82720a1..9f2cf272 100644 --- a/transports/websocket/src/framed.rs +++ b/transports/websocket/src/framed.rs @@ -20,7 +20,8 @@ use bytes::BytesMut; use crate::{error::Error, tls}; -use futures::{future::{self, Either, Loop}, prelude::*, try_ready}; +use futures::{future::{self, Either, Loop}, prelude::*, ready}; +use futures_codec::{Framed, FramedParts}; use libp2p_core::{ Transport, either::EitherOutput, @@ -35,9 +36,7 @@ use soketto::{ extension::deflate::Deflate, handshake::{self, Redirect, Response} }; -use std::{convert::TryFrom, io}; -use tokio_codec::{Framed, FramedParts}; -use tokio_io::{AsyncRead, AsyncWrite}; +use std::{convert::TryFrom, io, pin::Pin, task::Context, task::Poll}; use tokio_rustls::webpki; use url::Url; @@ -114,9 +113,9 @@ where { type Output = BytesConnection; type Error = Error; - type Listener = Box, Error = Self::Error> + Send>; - type ListenerUpgrade = Box + Send>; - type Dial = Box + Send>; + type Listener = Pin, Self::Error>> + Send>>; + type ListenerUpgrade = Pin> + Send>>; + type Dial = Pin> + Send>>; fn listen_on(self, addr: Multiaddr) -> Result> { let mut inner_addr = addr.clone(); @@ -170,9 +169,9 @@ where Error::Tls(tls::Error::from(e)) }) .map(|s| EitherOutput::First(EitherOutput::Second(s))); - Either::A(future) + Either::Left(future) } else { // continue with plain stream - Either::B(future::ok(EitherOutput::Second(stream))) + Either::Right(future::ok(EitherOutput::Second(stream))) } }) .and_then(move |stream| { @@ -188,7 +187,7 @@ where if let Some(r) = request { trace!("accepting websocket handshake request from {}", remote2); let key = Vec::from(r.key()); - Either::A(framed.send(Ok(handshake::Accept::new(key))) + Either::Left(framed.send(Ok(handshake::Accept::new(key))) .map_err(|e| Error::Base(Box::new(e))) .map(move |f| { trace!("websocket handshake with {} successful", remote2); @@ -200,7 +199,7 @@ where } else { debug!("connection to {} terminated during handshake", remote2); let e: io::Error = io::ErrorKind::ConnectionAborted.into(); - Either::B(future::err(Error::Handshake(Box::new(e)))) + Either::Right(future::err(Error::Handshake(Box::new(e)))) } }) }); @@ -211,7 +210,7 @@ where } } }); - 
Ok(Box::new(listen) as Box<_>) + Ok(Box::pin(listen) as Box<_>) } fn dial(self, addr: Multiaddr) -> Result> { @@ -226,7 +225,7 @@ where let max_redirects = self.max_redirects; let future = future::loop_fn((addr, self, max_redirects), |(addr, cfg, remaining)| { dial(addr, cfg.clone()).and_then(move |result| match result { - Either::A(redirect) => { + Either::Left(redirect) => { if remaining == 0 { debug!("too many redirects"); return Err(Error::TooManyRedirects) @@ -234,16 +233,16 @@ where let a = location_to_multiaddr(redirect.location())?; Ok(Loop::Continue((a, cfg, remaining - 1))) } - Either::B(conn) => Ok(Loop::Break(conn)) + Either::Right(conn) => Ok(Loop::Break(conn)) }) }); - Ok(Box::new(future) as Box<_>) + Ok(Box::pin(future) as Box<_>) } } /// Attempty to dial the given address and perform a websocket handshake. fn dial(address: Multiaddr, config: WsConfig) - -> impl Future>, Error = Error> + -> impl Future>, Error>> where T: Transport, T::Output: AsyncRead + AsyncWrite @@ -254,7 +253,7 @@ where let (host_port, dns_name) = match host_and_dnsname(&address) { Ok(x) => x, - Err(e) => return Either::A(future::err(e)) + Err(e) => return Either::Left(future::err(e)) }; let mut inner_addr = address.clone(); @@ -264,22 +263,22 @@ where Some(Protocol::Wss(path)) => { if dns_name.is_none() { debug!("no DNS name in {}", address); - return Either::A(future::err(Error::InvalidMultiaddr(address))) + return Either::Left(future::err(Error::InvalidMultiaddr(address))) } (true, path) } _ => { debug!("{} is not a websocket multiaddr", address); - return Either::A(future::err(Error::InvalidMultiaddr(address))) + return Either::Left(future::err(Error::InvalidMultiaddr(address))) } }; let dial = match transport.dial(inner_addr) { Ok(dial) => dial, Err(TransportError::MultiaddrNotSupported(a)) => - return Either::A(future::err(Error::InvalidMultiaddr(a))), + return Either::Left(future::err(Error::InvalidMultiaddr(a))), Err(TransportError::Other(e)) => - return Either::A(future::err(Error::Transport(e))) + return Either::Left(future::err(Error::Transport(e))) }; let address1 = address.clone(); // used for logging @@ -297,10 +296,10 @@ where Error::Tls(tls::Error::from(e)) }) .map(|s| EitherOutput::First(EitherOutput::First(s))); - return Either::A(future) + return Either::Left(future) } // continue with plain stream - Either::B(future::ok(EitherOutput::Second(stream))) + Either::Right(future::ok(EitherOutput::Second(stream))) }) .and_then(move |stream| { trace!("sending websocket handshake request to {}", address1); @@ -324,7 +323,7 @@ where } Some(Response::Redirect(r)) => { debug!("received {}", r); - return Ok(Either::A(r)) + return Ok(Either::Left(r)) } Some(Response::Accepted(_)) => { trace!("websocket handshake with {} successful", address1) @@ -332,11 +331,11 @@ where } let (mut handshake, mut c) = new_connection(framed, max_data_size, Mode::Client); c.add_extensions(handshake.drain_extensions()); - Ok(Either::B(BytesConnection { inner: c })) + Ok(Either::Right(BytesConnection { inner: c })) }) }); - Either::B(future) + Either::Right(future) } // Extract host, port and optionally the DNS name from the given [`Multiaddr`]. 
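// Sketch of the `Either` rename visible above: futures 0.1's `Either::A`/`Either::B`
// become `Either::Left`/`Either::Right`, and `FutureExt::left_future()` /
// `right_future()` wrap two branches with different future types into a single
// return type. `maybe_delayed` is an illustrative function.
use futures::future::{self, FutureExt};

fn maybe_delayed(fast: bool) -> impl std::future::Future<Output = u32> {
    if fast {
        future::ready(1u32).left_future()
    } else {
        (async { 2u32 }).right_future()
    }
}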
@@ -423,36 +422,35 @@ pub struct BytesConnection { } impl Stream for BytesConnection { - type Item = BytesMut; + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let data = ready!(self.inner.poll(cx).map_err(|e| io::Error::new(io::ErrorKind::Other, e))); + Poll::Ready(data.map(base::Data::into_bytes)) + } +} + +impl Sink for BytesConnection { type Error = io::Error; - fn poll(&mut self) -> Poll, Self::Error> { - let data = try_ready!(self.inner.poll().map_err(|e| io::Error::new(io::ErrorKind::Other, e))); - Ok(Async::Ready(data.map(base::Data::into_bytes))) - } -} - -impl Sink for BytesConnection { - type SinkItem = BytesMut; - type SinkError = io::Error; - - fn start_send(&mut self, item: Self::SinkItem) -> StartSend { - let result = self.inner.start_send(base::Data::Binary(item)) - .map_err(|e| io::Error::new(io::ErrorKind::Other, e)); - - if let AsyncSink::NotReady(data) = result? { - Ok(AsyncSink::NotReady(data.into_bytes())) - } else { - Ok(AsyncSink::Ready) - } - } - - fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { - self.inner.poll_complete().map_err(|e| io::Error::new(io::ErrorKind::Other, e)) - } - - fn close(&mut self) -> Poll<(), Self::SinkError> { - self.inner.close().map_err(|e| io::Error::new(io::ErrorKind::Other, e)) + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Sink::poll_ready(Pin::new(&mut self.inner), cx) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) + } + + fn start_send(self: Pin<&mut Self>, item: BytesMut) -> Result<(), Self::Error> { + self.inner.start_send(base::Data::Binary(item)) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Sink::poll_flush(Pin::new(&mut self.inner), cx) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Sink::poll_close(Pin::new(&mut self.inner), cx) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) } } diff --git a/transports/websocket/src/lib.rs b/transports/websocket/src/lib.rs index 533e1b78..cfc28088 100644 --- a/transports/websocket/src/lib.rs +++ b/transports/websocket/src/lib.rs @@ -34,7 +34,7 @@ use libp2p_core::{ transport::{map::{MapFuture, MapStream}, ListenerEvent, TransportError} }; use rw_stream_sink::RwStreamSink; -use tokio_io::{AsyncRead, AsyncWrite}; +use std::pin::Pin; /// A Websocket transport. #[derive(Debug, Clone)] @@ -117,11 +117,11 @@ where /// Type alias corresponding to `framed::WsConfig::Listener`. pub type InnerStream = - Box<(dyn Stream, Item = ListenerEvent>> + Send)>; + Pin>, Error>> + Send)>>; /// Type alias corresponding to `framed::WsConfig::Dial` and `framed::WsConfig::ListenerUpgrade`. pub type InnerFuture = - Box<(dyn Future, Error = Error> + Send)>; + Pin, Error>> + Send)>>; /// Function type that wraps a websocket connection (see. `wrap_connection`). pub type WrapperFn =
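// Sketch of the futures 0.3 `Sink` shape used for `BytesConnection` above: the old
// `start_send` returning `AsyncSink::NotReady` is split into `poll_ready`
// (backpressure) plus a plain `start_send`, and `poll_complete`/`close` become
// `poll_flush`/`poll_close`. `Prefixed` is an illustrative delegating wrapper around
// an `Unpin` inner sink.
use futures::Sink;
use std::pin::Pin;
use std::task::{Context, Poll};

struct Prefixed<S> {
    inner: S,
}

impl<S> Sink<String> for Prefixed<S>
where
    S: Sink<String> + Unpin,
{
    type Error = S::Error;

    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::new(&mut self.inner).poll_ready(cx)
    }

    fn start_send(mut self: Pin<&mut Self>, item: String) -> Result<(), Self::Error> {
        Pin::new(&mut self.inner).start_send(format!("> {}", item))
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::new(&mut self.inner).poll_flush(cx)
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::new(&mut self.inner).poll_close(cx)
    }
}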