diff --git a/core/CHANGELOG.md b/core/CHANGELOG.md index 047a7ac4..dc4fd082 100644 --- a/core/CHANGELOG.md +++ b/core/CHANGELOG.md @@ -4,11 +4,13 @@ and `poll_address_change`. Consequently, `StreamMuxerEvent` is also removed. See [PR 2724]. - Drop `Unpin` requirement from `SubstreamBox`. See [PR 2762] and [PR 2776]. - Drop `Sync` requirement on `StreamMuxer` for constructing `StreamMuxerBox`. See [PR 2775]. +- Use `Pin<&mut Self>` as the receiver type for all `StreamMuxer` poll functions. See [PR 2765]. [PR 2724]: https://github.com/libp2p/rust-libp2p/pull/2724 [PR 2762]: https://github.com/libp2p/rust-libp2p/pull/2762 [PR 2775]: https://github.com/libp2p/rust-libp2p/pull/2775 [PR 2776]: https://github.com/libp2p/rust-libp2p/pull/2776 +[PR 2765]: https://github.com/libp2p/rust-libp2p/pull/2765 # 0.34.0 diff --git a/core/src/either.rs b/core/src/either.rs index 4b5c20b2..42984519 100644 --- a/core/src/either.rs +++ b/core/src/either.rs @@ -204,43 +204,54 @@ where type Substream = EitherOutput; type Error = EitherError; - fn poll_inbound(&self, cx: &mut Context<'_>) -> Poll> { - match self { - EitherOutput::First(inner) => inner + fn poll_inbound( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match self.project() { + EitherOutputProj::First(inner) => inner .poll_inbound(cx) .map_ok(EitherOutput::First) .map_err(EitherError::A), - EitherOutput::Second(inner) => inner + EitherOutputProj::Second(inner) => inner .poll_inbound(cx) .map_ok(EitherOutput::Second) .map_err(EitherError::B), } } - fn poll_outbound(&self, cx: &mut Context<'_>) -> Poll> { - match self { - EitherOutput::First(inner) => inner + fn poll_outbound( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match self.project() { + EitherOutputProj::First(inner) => inner .poll_outbound(cx) .map_ok(EitherOutput::First) .map_err(EitherError::A), - EitherOutput::Second(inner) => inner + EitherOutputProj::Second(inner) => inner .poll_outbound(cx) .map_ok(EitherOutput::Second) .map_err(EitherError::B), } } - fn poll_address_change(&self, cx: &mut Context<'_>) -> Poll> { - match self { - EitherOutput::First(inner) => inner.poll_address_change(cx).map_err(EitherError::A), - EitherOutput::Second(inner) => inner.poll_address_change(cx).map_err(EitherError::B), + fn poll_address_change( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match self.project() { + EitherOutputProj::First(inner) => inner.poll_address_change(cx).map_err(EitherError::A), + EitherOutputProj::Second(inner) => { + inner.poll_address_change(cx).map_err(EitherError::B) + } } } - fn poll_close(&self, cx: &mut Context<'_>) -> Poll> { - match self { - EitherOutput::First(inner) => inner.poll_close(cx).map_err(EitherError::A), - EitherOutput::Second(inner) => inner.poll_close(cx).map_err(EitherError::B), + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project() { + EitherOutputProj::First(inner) => inner.poll_close(cx).map_err(EitherError::A), + EitherOutputProj::Second(inner) => inner.poll_close(cx).map_err(EitherError::B), } } } diff --git a/core/src/muxing.rs b/core/src/muxing.rs index a2bdfa80..2d1e1068 100644 --- a/core/src/muxing.rs +++ b/core/src/muxing.rs @@ -52,6 +52,8 @@ use futures::{task::Context, task::Poll, AsyncRead, AsyncWrite}; use multiaddr::Multiaddr; +use std::future::Future; +use std::pin::Pin; pub use self::boxed::StreamMuxerBox; pub use self::boxed::SubstreamBox; @@ -73,15 +75,24 @@ pub trait StreamMuxer { type Error: std::error::Error; /// Poll for new inbound substreams. - fn poll_inbound(&self, cx: &mut Context<'_>) -> Poll>; + fn poll_inbound( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>; /// Poll for a new, outbound substream. - fn poll_outbound(&self, cx: &mut Context<'_>) -> Poll>; + fn poll_outbound( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>; /// Poll for an address change of the underlying connection. /// /// Not all implementations may support this feature. - fn poll_address_change(&self, cx: &mut Context<'_>) -> Poll>; + fn poll_address_change( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>; /// Closes this `StreamMuxer`. /// @@ -93,5 +104,105 @@ pub trait StreamMuxer { /// > that the remote is properly informed of the shutdown. However, apart from /// > properly informing the remote, there is no difference between this and /// > immediately dropping the muxer. - fn poll_close(&self, cx: &mut Context<'_>) -> Poll>; + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; +} + +/// Extension trait for [`StreamMuxer`]. +pub trait StreamMuxerExt: StreamMuxer + Sized { + /// Convenience function for calling [`StreamMuxer::poll_inbound`] for [`StreamMuxer`]s that are `Unpin`. + fn poll_inbound_unpin( + &mut self, + cx: &mut Context<'_>, + ) -> Poll> + where + Self: Unpin, + { + Pin::new(self).poll_inbound(cx) + } + + /// Convenience function for calling [`StreamMuxer::poll_outbound`] for [`StreamMuxer`]s that are `Unpin`. + fn poll_outbound_unpin( + &mut self, + cx: &mut Context<'_>, + ) -> Poll> + where + Self: Unpin, + { + Pin::new(self).poll_outbound(cx) + } + + /// Convenience function for calling [`StreamMuxer::poll_address_change`] for [`StreamMuxer`]s that are `Unpin`. + fn poll_address_change_unpin( + &mut self, + cx: &mut Context<'_>, + ) -> Poll> + where + Self: Unpin, + { + Pin::new(self).poll_address_change(cx) + } + + /// Convenience function for calling [`StreamMuxer::poll_close`] for [`StreamMuxer`]s that are `Unpin`. + fn poll_close_unpin(&mut self, cx: &mut Context<'_>) -> Poll> + where + Self: Unpin, + { + Pin::new(self).poll_close(cx) + } + + /// Returns a future that resolves to the next inbound `Substream` opened by the remote. + fn next_inbound(&mut self) -> NextInbound<'_, Self> { + NextInbound(self) + } + + /// Returns a future that opens a new outbound `Substream` with the remote. + fn next_outbound(&mut self) -> NextOutbound<'_, Self> { + NextOutbound(self) + } + + /// Returns a future for closing this [`StreamMuxer`]. + fn close(self) -> Close { + Close(self) + } +} + +impl StreamMuxerExt for S where S: StreamMuxer {} + +pub struct NextInbound<'a, S>(&'a mut S); + +pub struct NextOutbound<'a, S>(&'a mut S); + +pub struct Close(S); + +impl<'a, S> Future for NextInbound<'a, S> +where + S: StreamMuxer + Unpin, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.0.poll_inbound_unpin(cx) + } +} + +impl<'a, S> Future for NextOutbound<'a, S> +where + S: StreamMuxer + Unpin, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.0.poll_outbound_unpin(cx) + } +} + +impl Future for Close +where + S: StreamMuxer + Unpin, +{ + type Output = Result<(), S::Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.0.poll_close_unpin(cx) + } } diff --git a/core/src/muxing/boxed.rs b/core/src/muxing/boxed.rs index 8c6467dd..0f5b6e58 100644 --- a/core/src/muxing/boxed.rs +++ b/core/src/muxing/boxed.rs @@ -1,6 +1,7 @@ use crate::StreamMuxer; use futures::{AsyncRead, AsyncWrite}; use multiaddr::Multiaddr; +use pin_project::pin_project; use std::error::Error; use std::fmt; use std::io; @@ -10,7 +11,7 @@ use std::task::{Context, Poll}; /// Abstract `StreamMuxer`. pub struct StreamMuxerBox { - inner: Box + Send>, + inner: Pin + Send>>, } /// Abstract type for asynchronous reading and writing. @@ -19,10 +20,12 @@ pub struct StreamMuxerBox { /// and `AsyncWrite` capabilities. pub struct SubstreamBox(Pin>); +#[pin_project] struct Wrap where T: StreamMuxer, { + #[pin] inner: T, } @@ -36,26 +39,40 @@ where type Error = io::Error; #[inline] - fn poll_close(&self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_close(cx).map_err(into_io_error) + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_close(cx).map_err(into_io_error) } - fn poll_inbound(&self, cx: &mut Context<'_>) -> Poll> { - self.inner + fn poll_inbound( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.project() + .inner .poll_inbound(cx) .map_ok(SubstreamBox::new) .map_err(into_io_error) } - fn poll_outbound(&self, cx: &mut Context<'_>) -> Poll> { - self.inner + fn poll_outbound( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.project() + .inner .poll_outbound(cx) .map_ok(SubstreamBox::new) .map_err(into_io_error) } - fn poll_address_change(&self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_address_change(cx).map_err(into_io_error) + fn poll_address_change( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.project() + .inner + .poll_address_change(cx) + .map_err(into_io_error) } } @@ -77,9 +94,15 @@ impl StreamMuxerBox { let wrap = Wrap { inner: muxer }; StreamMuxerBox { - inner: Box::new(wrap), + inner: Box::pin(wrap), } } + + fn project( + self: Pin<&mut Self>, + ) -> Pin<&mut (dyn StreamMuxer + Send)> { + self.get_mut().inner.as_mut() + } } impl StreamMuxer for StreamMuxerBox { @@ -87,20 +110,29 @@ impl StreamMuxer for StreamMuxerBox { type Error = io::Error; #[inline] - fn poll_close(&self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_close(cx) + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().poll_close(cx) } - fn poll_inbound(&self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_inbound(cx) + fn poll_inbound( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.project().poll_inbound(cx) } - fn poll_outbound(&self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_outbound(cx) + fn poll_outbound( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.project().poll_outbound(cx) } - fn poll_address_change(&self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_address_change(cx) + fn poll_address_change( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.project().poll_address_change(cx) } } diff --git a/core/src/muxing/singleton.rs b/core/src/muxing/singleton.rs index d67cb5e9..193cfb63 100644 --- a/core/src/muxing/singleton.rs +++ b/core/src/muxing/singleton.rs @@ -23,6 +23,7 @@ use crate::{connection::Endpoint, muxing::StreamMuxer}; use futures::prelude::*; use multiaddr::Multiaddr; use std::cell::Cell; +use std::pin::Pin; use std::{io, task::Context, task::Poll}; /// Implementation of `StreamMuxer` that allows only one substream on top of a connection, @@ -57,31 +58,44 @@ where type Substream = TSocket; type Error = io::Error; - fn poll_inbound(&self, _: &mut Context<'_>) -> Poll> { - match self.endpoint { + fn poll_inbound( + self: Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll> { + let this = self.get_mut(); + + match this.endpoint { Endpoint::Dialer => Poll::Pending, - Endpoint::Listener => match self.inner.replace(None) { + Endpoint::Listener => match this.inner.replace(None) { None => Poll::Pending, Some(stream) => Poll::Ready(Ok(stream)), }, } } - fn poll_outbound(&self, _: &mut Context<'_>) -> Poll> { - match self.endpoint { + fn poll_outbound( + self: Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll> { + let this = self.get_mut(); + + match this.endpoint { Endpoint::Listener => Poll::Pending, - Endpoint::Dialer => match self.inner.replace(None) { + Endpoint::Dialer => match this.inner.replace(None) { None => Poll::Pending, Some(stream) => Poll::Ready(Ok(stream)), }, } } - fn poll_address_change(&self, _: &mut Context<'_>) -> Poll> { + fn poll_address_change( + self: Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll> { Poll::Pending } - fn poll_close(&self, _cx: &mut Context<'_>) -> Poll> { + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } } diff --git a/muxers/mplex/benches/split_send_size.rs b/muxers/mplex/benches/split_send_size.rs index d536edf4..f74bcd10 100644 --- a/muxers/mplex/benches/split_send_size.rs +++ b/muxers/mplex/benches/split_send_size.rs @@ -26,9 +26,9 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughpu use futures::future::poll_fn; use futures::prelude::*; use futures::{channel::oneshot, future::join}; +use libp2p_core::muxing::StreamMuxerExt; use libp2p_core::{ - identity, multiaddr::multiaddr, muxing, transport, upgrade, Multiaddr, PeerId, StreamMuxer, - Transport, + identity, multiaddr::multiaddr, muxing, transport, upgrade, Multiaddr, PeerId, Transport, }; use libp2p_mplex as mplex; use libp2p_plaintext::PlainText2Config; @@ -113,10 +113,8 @@ fn run( addr_sender.take().unwrap().send(listen_addr).unwrap(); } transport::TransportEvent::Incoming { upgrade, .. } => { - let (_peer, conn) = upgrade.await.unwrap(); - let mut s = poll_fn(|cx| conn.poll_inbound(cx)) - .await - .expect("unexpected error"); + let (_peer, mut conn) = upgrade.await.unwrap(); + let mut s = conn.next_inbound().await.expect("unexpected error"); let mut buf = vec![0u8; payload_len]; let mut off = 0; @@ -140,8 +138,8 @@ fn run( // Spawn and block on the sender, i.e. until all data is sent. let sender = async move { let addr = addr_receiver.await.unwrap(); - let (_peer, conn) = sender_trans.dial(addr).unwrap().await.unwrap(); - let mut stream = poll_fn(|cx| conn.poll_outbound(cx)).await.unwrap(); + let (_peer, mut conn) = sender_trans.dial(addr).unwrap().await.unwrap(); + let mut stream = conn.next_outbound().await.unwrap(); let mut off = 0; loop { let n = poll_fn(|cx| Pin::new(&mut stream).poll_write(cx, &payload[off..])) diff --git a/muxers/mplex/src/lib.rs b/muxers/mplex/src/lib.rs index 59b38db1..14f9cda6 100644 --- a/muxers/mplex/src/lib.rs +++ b/muxers/mplex/src/lib.rs @@ -85,25 +85,34 @@ where type Substream = Substream; type Error = io::Error; - fn poll_inbound(&self, cx: &mut Context<'_>) -> Poll> { + fn poll_inbound( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { self.io .lock() .poll_next_stream(cx) .map_ok(|stream_id| Substream::new(stream_id, self.io.clone())) } - fn poll_outbound(&self, cx: &mut Context<'_>) -> Poll> { + fn poll_outbound( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { self.io .lock() .poll_open_stream(cx) .map_ok(|stream_id| Substream::new(stream_id, self.io.clone())) } - fn poll_address_change(&self, _: &mut Context<'_>) -> Poll> { + fn poll_address_change( + self: Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll> { Poll::Pending } - fn poll_close(&self, cx: &mut Context<'_>) -> Poll> { + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.io.lock().poll_close(cx) } } diff --git a/muxers/mplex/tests/async_write.rs b/muxers/mplex/tests/async_write.rs index 2c4a2d10..bfbabf0f 100644 --- a/muxers/mplex/tests/async_write.rs +++ b/muxers/mplex/tests/async_write.rs @@ -18,9 +18,9 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use futures::future::poll_fn; use futures::{channel::oneshot, prelude::*}; -use libp2p_core::{upgrade, StreamMuxer, Transport}; +use libp2p_core::muxing::StreamMuxerExt; +use libp2p_core::{upgrade, Transport}; use libp2p_tcp::TcpTransport; #[test] @@ -49,7 +49,7 @@ fn async_write() { tx.send(addr).unwrap(); - let client = transport + let mut client = transport .next() .await .expect("some event") @@ -59,7 +59,7 @@ fn async_write() { .await .unwrap(); - let mut outbound = poll_fn(|cx| client.poll_outbound(cx)).await.unwrap(); + let mut outbound = client.next_outbound().await.unwrap(); let mut buf = Vec::new(); outbound.read_to_end(&mut buf).await.unwrap(); @@ -71,8 +71,9 @@ fn async_write() { let mut transport = TcpTransport::default() .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)); - let client = transport.dial(rx.await.unwrap()).unwrap().await.unwrap(); - let mut inbound = poll_fn(|cx| client.poll_inbound(cx)).await.unwrap(); + let mut client = transport.dial(rx.await.unwrap()).unwrap().await.unwrap(); + + let mut inbound = client.next_inbound().await.unwrap(); inbound.write_all(b"hello world").await.unwrap(); // The test consists in making sure that this flushes the substream. diff --git a/muxers/mplex/tests/two_peers.rs b/muxers/mplex/tests/two_peers.rs index 2b976c12..d30fcc10 100644 --- a/muxers/mplex/tests/two_peers.rs +++ b/muxers/mplex/tests/two_peers.rs @@ -18,9 +18,9 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use futures::future::poll_fn; use futures::{channel::oneshot, prelude::*}; -use libp2p_core::{upgrade, StreamMuxer, Transport}; +use libp2p_core::muxing::StreamMuxerExt; +use libp2p_core::{upgrade, Transport}; use libp2p_tcp::TcpTransport; #[test] @@ -49,7 +49,7 @@ fn client_to_server_outbound() { tx.send(addr).unwrap(); - let client = transport + let mut client = transport .next() .await .expect("some event") @@ -59,7 +59,7 @@ fn client_to_server_outbound() { .await .unwrap(); - let mut outbound = poll_fn(|cx| client.poll_outbound(cx)).await.unwrap(); + let mut outbound = client.next_outbound().await.unwrap(); let mut buf = Vec::new(); outbound.read_to_end(&mut buf).await.unwrap(); @@ -72,8 +72,8 @@ fn client_to_server_outbound() { .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)) .boxed(); - let client = transport.dial(rx.await.unwrap()).unwrap().await.unwrap(); - let mut inbound = poll_fn(|cx| client.poll_inbound(cx)).await.unwrap(); + let mut client = transport.dial(rx.await.unwrap()).unwrap().await.unwrap(); + let mut inbound = client.next_inbound().await.unwrap(); inbound.write_all(b"hello world").await.unwrap(); inbound.close().await.unwrap(); @@ -107,7 +107,7 @@ fn client_to_server_inbound() { tx.send(addr).unwrap(); - let client = transport + let mut client = transport .next() .await .expect("some event") @@ -117,7 +117,7 @@ fn client_to_server_inbound() { .await .unwrap(); - let mut inbound = poll_fn(|cx| client.poll_inbound(cx)).await.unwrap(); + let mut inbound = client.next_inbound().await.unwrap(); let mut buf = Vec::new(); inbound.read_to_end(&mut buf).await.unwrap(); @@ -130,9 +130,9 @@ fn client_to_server_inbound() { .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)) .boxed(); - let client = transport.dial(rx.await.unwrap()).unwrap().await.unwrap(); + let mut client = transport.dial(rx.await.unwrap()).unwrap().await.unwrap(); - let mut outbound = poll_fn(|cx| client.poll_outbound(cx)).await.unwrap(); + let mut outbound = client.next_outbound().await.unwrap(); outbound.write_all(b"hello world").await.unwrap(); outbound.close().await.unwrap(); @@ -164,7 +164,7 @@ fn protocol_not_match() { tx.send(addr).unwrap(); - let client = transport + let mut client = transport .next() .await .expect("some event") @@ -174,7 +174,7 @@ fn protocol_not_match() { .await .unwrap(); - let mut outbound = poll_fn(|cx| client.poll_outbound(cx)).await.unwrap(); + let mut outbound = client.next_outbound().await.unwrap(); let mut buf = Vec::new(); outbound.read_to_end(&mut buf).await.unwrap(); diff --git a/muxers/yamux/src/lib.rs b/muxers/yamux/src/lib.rs index a06e7934..07327e20 100644 --- a/muxers/yamux/src/lib.rs +++ b/muxers/yamux/src/lib.rs @@ -29,7 +29,6 @@ use futures::{ use libp2p_core::muxing::StreamMuxer; use libp2p_core::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo}; use libp2p_core::Multiaddr; -use parking_lot::Mutex; use std::{ fmt, io, iter, mem, pin::Pin, @@ -39,7 +38,12 @@ use thiserror::Error; use yamux::ConnectionError; /// A Yamux connection. -pub struct Yamux(Mutex>); +pub struct Yamux { + /// The [`futures::stream::Stream`] of incoming substreams. + incoming: S, + /// Handle to control the connection. + control: yamux::Control, +} impl fmt::Debug for Yamux { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -47,13 +51,6 @@ impl fmt::Debug for Yamux { } } -struct Inner { - /// The [`futures::stream::Stream`] of incoming substreams. - incoming: S, - /// Handle to control the connection. - control: yamux::Control, -} - /// A token to poll for an outbound substream. #[derive(Debug)] pub struct OpenSubstreamToken(()); @@ -66,14 +63,14 @@ where fn new(io: C, cfg: yamux::Config, mode: yamux::Mode) -> Self { let conn = yamux::Connection::new(io, cfg, mode); let ctrl = conn.control(); - let inner = Inner { + + Yamux { incoming: Incoming { stream: yamux::into_stream(conn).err_into().boxed(), _marker: std::marker::PhantomData, }, control: ctrl, - }; - Yamux(Mutex::new(inner)) + } } } @@ -85,14 +82,14 @@ where fn local(io: C, cfg: yamux::Config, mode: yamux::Mode) -> Self { let conn = yamux::Connection::new(io, cfg, mode); let ctrl = conn.control(); - let inner = Inner { + + Yamux { incoming: LocalIncoming { stream: yamux::into_stream(conn).err_into().boxed_local(), _marker: std::marker::PhantomData, }, control: ctrl, - }; - Yamux(Mutex::new(inner)) + } } } @@ -105,41 +102,44 @@ where type Substream = yamux::Stream; type Error = YamuxError; - fn poll_inbound(&self, cx: &mut Context<'_>) -> Poll> { - self.0 - .lock() - .incoming - .poll_next_unpin(cx) - .map(|maybe_stream| { - let stream = maybe_stream - .transpose()? - .ok_or(YamuxError(ConnectionError::Closed))?; + fn poll_inbound( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.incoming.poll_next_unpin(cx).map(|maybe_stream| { + let stream = maybe_stream + .transpose()? + .ok_or(YamuxError(ConnectionError::Closed))?; - Ok(stream) - }) + Ok(stream) + }) } - fn poll_outbound(&self, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0.lock().control) + fn poll_outbound( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.control) .poll_open_stream(cx) .map_err(YamuxError) } - fn poll_address_change(&self, _: &mut Context<'_>) -> Poll> { + fn poll_address_change( + self: Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll> { Poll::Pending } - fn poll_close(&self, c: &mut Context<'_>) -> Poll> { - let mut inner = self.0.lock(); - - if let Poll::Ready(()) = Pin::new(&mut inner.control) + fn poll_close(mut self: Pin<&mut Self>, c: &mut Context<'_>) -> Poll> { + if let Poll::Ready(()) = Pin::new(&mut self.control) .poll_close(c) .map_err(YamuxError)? { return Poll::Ready(Ok(())); } - while let Poll::Ready(maybe_inbound_stream) = inner.incoming.poll_next_unpin(c)? { + while let Poll::Ready(maybe_inbound_stream) = self.incoming.poll_next_unpin(c)? { match maybe_inbound_stream { Some(inbound_stream) => mem::drop(inbound_stream), None => return Poll::Ready(Ok(())), diff --git a/swarm/src/connection.rs b/swarm/src/connection.rs index 8d29ca53..f9218661 100644 --- a/swarm/src/connection.rs +++ b/swarm/src/connection.rs @@ -32,13 +32,12 @@ pub use pool::{EstablishedConnection, PendingConnection}; use crate::handler::ConnectionHandler; use crate::IntoConnectionHandler; -use futures::future::poll_fn; use handler_wrapper::HandlerWrapper; use libp2p_core::connection::ConnectedPoint; use libp2p_core::multiaddr::Multiaddr; -use libp2p_core::muxing::StreamMuxerBox; +use libp2p_core::muxing::{StreamMuxerBox, StreamMuxerExt}; +use libp2p_core::upgrade; use libp2p_core::PeerId; -use libp2p_core::{upgrade, StreamMuxer}; use std::collections::VecDeque; use std::future::Future; use std::{error::Error, fmt, io, pin::Pin, task::Context, task::Poll}; @@ -132,10 +131,7 @@ where /// Begins an orderly shutdown of the connection, returning the connection /// handler and a `Future` that resolves when connection shutdown is complete. pub fn close(self) -> (THandler, impl Future>) { - ( - self.handler.into_connection_handler(), - poll_fn(move |cx| self.muxing.poll_close(cx)), - ) + (self.handler.into_connection_handler(), self.muxing.close()) } /// Polls the handler and the substream, forwarding events from the former to the latter and @@ -158,7 +154,7 @@ where } if !self.open_info.is_empty() { - if let Poll::Ready(substream) = self.muxing.poll_outbound(cx)? { + if let Poll::Ready(substream) = self.muxing.poll_outbound_unpin(cx)? { let user_data = self .open_info .pop_front() @@ -169,13 +165,13 @@ where } } - if let Poll::Ready(substream) = self.muxing.poll_inbound(cx)? { + if let Poll::Ready(substream) = self.muxing.poll_inbound_unpin(cx)? { self.handler .inject_substream(substream, SubstreamEndpoint::Listener); continue; // Go back to the top, handler can potentially make progress again. } - if let Poll::Ready(address) = self.muxing.poll_address_change(cx)? { + if let Poll::Ready(address) = self.muxing.poll_address_change_unpin(cx)? { self.handler.inject_address_change(&address); return Poll::Ready(Ok(Event::AddressChange(address))); } diff --git a/swarm/src/connection/pool.rs b/swarm/src/connection/pool.rs index 4bbdf9c4..62e931e9 100644 --- a/swarm/src/connection/pool.rs +++ b/swarm/src/connection/pool.rs @@ -38,7 +38,7 @@ use futures::{ stream::FuturesUnordered, }; use libp2p_core::connection::{ConnectionId, Endpoint, PendingPoint}; -use libp2p_core::muxing::{StreamMuxer, StreamMuxerBox}; +use libp2p_core::muxing::{StreamMuxerBox, StreamMuxerExt}; use std::{ collections::{hash_map, HashMap}, convert::TryFrom as _, @@ -604,7 +604,7 @@ where match event { task::PendingConnectionEvent::ConnectionEstablished { id, - output: (obtained_peer_id, muxer), + output: (obtained_peer_id, mut muxer), outgoing, } => { let PendingConnectionInfo { @@ -692,7 +692,7 @@ where if let Err(error) = error { self.spawn( poll_fn(move |cx| { - if let Err(e) = ready!(muxer.poll_close(cx)) { + if let Err(e) = ready!(muxer.poll_close_unpin(cx)) { log::debug!( "Failed to close connection {:?} to peer {}: {:?}", id,