core/muxing: Have functions on StreamMuxer take Pin<&mut Self> (#2765)

Co-authored-by: Elena Frank <elena.frank@protonmail.com>
Co-authored-by: Max Inden <mail@max-inden.de>
This commit is contained in:
Thomas Eizinger 2022-08-03 15:12:11 +02:00 committed by GitHub
parent 2b9e212682
commit 028decec69
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 297 additions and 123 deletions

View File

@ -4,11 +4,13 @@
and `poll_address_change`. Consequently, `StreamMuxerEvent` is also removed. See [PR 2724]. and `poll_address_change`. Consequently, `StreamMuxerEvent` is also removed. See [PR 2724].
- Drop `Unpin` requirement from `SubstreamBox`. See [PR 2762] and [PR 2776]. - Drop `Unpin` requirement from `SubstreamBox`. See [PR 2762] and [PR 2776].
- Drop `Sync` requirement on `StreamMuxer` for constructing `StreamMuxerBox`. See [PR 2775]. - Drop `Sync` requirement on `StreamMuxer` for constructing `StreamMuxerBox`. See [PR 2775].
- Use `Pin<&mut Self>` as the receiver type for all `StreamMuxer` poll functions. See [PR 2765].
[PR 2724]: https://github.com/libp2p/rust-libp2p/pull/2724 [PR 2724]: https://github.com/libp2p/rust-libp2p/pull/2724
[PR 2762]: https://github.com/libp2p/rust-libp2p/pull/2762 [PR 2762]: https://github.com/libp2p/rust-libp2p/pull/2762
[PR 2775]: https://github.com/libp2p/rust-libp2p/pull/2775 [PR 2775]: https://github.com/libp2p/rust-libp2p/pull/2775
[PR 2776]: https://github.com/libp2p/rust-libp2p/pull/2776 [PR 2776]: https://github.com/libp2p/rust-libp2p/pull/2776
[PR 2765]: https://github.com/libp2p/rust-libp2p/pull/2765
# 0.34.0 # 0.34.0

View File

@ -204,43 +204,54 @@ where
type Substream = EitherOutput<A::Substream, B::Substream>; type Substream = EitherOutput<A::Substream, B::Substream>;
type Error = EitherError<A::Error, B::Error>; type Error = EitherError<A::Error, B::Error>;
fn poll_inbound(&self, cx: &mut Context<'_>) -> Poll<Result<Self::Substream, Self::Error>> { fn poll_inbound(
match self { self: Pin<&mut Self>,
EitherOutput::First(inner) => inner cx: &mut Context<'_>,
) -> Poll<Result<Self::Substream, Self::Error>> {
match self.project() {
EitherOutputProj::First(inner) => inner
.poll_inbound(cx) .poll_inbound(cx)
.map_ok(EitherOutput::First) .map_ok(EitherOutput::First)
.map_err(EitherError::A), .map_err(EitherError::A),
EitherOutput::Second(inner) => inner EitherOutputProj::Second(inner) => inner
.poll_inbound(cx) .poll_inbound(cx)
.map_ok(EitherOutput::Second) .map_ok(EitherOutput::Second)
.map_err(EitherError::B), .map_err(EitherError::B),
} }
} }
fn poll_outbound(&self, cx: &mut Context<'_>) -> Poll<Result<Self::Substream, Self::Error>> { fn poll_outbound(
match self { self: Pin<&mut Self>,
EitherOutput::First(inner) => inner cx: &mut Context<'_>,
) -> Poll<Result<Self::Substream, Self::Error>> {
match self.project() {
EitherOutputProj::First(inner) => inner
.poll_outbound(cx) .poll_outbound(cx)
.map_ok(EitherOutput::First) .map_ok(EitherOutput::First)
.map_err(EitherError::A), .map_err(EitherError::A),
EitherOutput::Second(inner) => inner EitherOutputProj::Second(inner) => inner
.poll_outbound(cx) .poll_outbound(cx)
.map_ok(EitherOutput::Second) .map_ok(EitherOutput::Second)
.map_err(EitherError::B), .map_err(EitherError::B),
} }
} }
fn poll_address_change(&self, cx: &mut Context<'_>) -> Poll<Result<Multiaddr, Self::Error>> { fn poll_address_change(
match self { self: Pin<&mut Self>,
EitherOutput::First(inner) => inner.poll_address_change(cx).map_err(EitherError::A), cx: &mut Context<'_>,
EitherOutput::Second(inner) => inner.poll_address_change(cx).map_err(EitherError::B), ) -> Poll<Result<Multiaddr, Self::Error>> {
match self.project() {
EitherOutputProj::First(inner) => inner.poll_address_change(cx).map_err(EitherError::A),
EitherOutputProj::Second(inner) => {
inner.poll_address_change(cx).map_err(EitherError::B)
}
} }
} }
fn poll_close(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match self { match self.project() {
EitherOutput::First(inner) => inner.poll_close(cx).map_err(EitherError::A), EitherOutputProj::First(inner) => inner.poll_close(cx).map_err(EitherError::A),
EitherOutput::Second(inner) => inner.poll_close(cx).map_err(EitherError::B), EitherOutputProj::Second(inner) => inner.poll_close(cx).map_err(EitherError::B),
} }
} }
} }

View File

@ -52,6 +52,8 @@
use futures::{task::Context, task::Poll, AsyncRead, AsyncWrite}; use futures::{task::Context, task::Poll, AsyncRead, AsyncWrite};
use multiaddr::Multiaddr; use multiaddr::Multiaddr;
use std::future::Future;
use std::pin::Pin;
pub use self::boxed::StreamMuxerBox; pub use self::boxed::StreamMuxerBox;
pub use self::boxed::SubstreamBox; pub use self::boxed::SubstreamBox;
@ -73,15 +75,24 @@ pub trait StreamMuxer {
type Error: std::error::Error; type Error: std::error::Error;
/// Poll for new inbound substreams. /// Poll for new inbound substreams.
fn poll_inbound(&self, cx: &mut Context<'_>) -> Poll<Result<Self::Substream, Self::Error>>; fn poll_inbound(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Self::Substream, Self::Error>>;
/// Poll for a new, outbound substream. /// Poll for a new, outbound substream.
fn poll_outbound(&self, cx: &mut Context<'_>) -> Poll<Result<Self::Substream, Self::Error>>; fn poll_outbound(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Self::Substream, Self::Error>>;
/// Poll for an address change of the underlying connection. /// Poll for an address change of the underlying connection.
/// ///
/// Not all implementations may support this feature. /// Not all implementations may support this feature.
fn poll_address_change(&self, cx: &mut Context<'_>) -> Poll<Result<Multiaddr, Self::Error>>; fn poll_address_change(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Multiaddr, Self::Error>>;
/// Closes this `StreamMuxer`. /// Closes this `StreamMuxer`.
/// ///
@ -93,5 +104,105 @@ pub trait StreamMuxer {
/// > that the remote is properly informed of the shutdown. However, apart from /// > that the remote is properly informed of the shutdown. However, apart from
/// > properly informing the remote, there is no difference between this and /// > properly informing the remote, there is no difference between this and
/// > immediately dropping the muxer. /// > immediately dropping the muxer.
fn poll_close(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>; fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>;
}
/// Extension trait for [`StreamMuxer`].
pub trait StreamMuxerExt: StreamMuxer + Sized {
/// Convenience function for calling [`StreamMuxer::poll_inbound`] for [`StreamMuxer`]s that are `Unpin`.
fn poll_inbound_unpin(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Result<Self::Substream, Self::Error>>
where
Self: Unpin,
{
Pin::new(self).poll_inbound(cx)
}
/// Convenience function for calling [`StreamMuxer::poll_outbound`] for [`StreamMuxer`]s that are `Unpin`.
fn poll_outbound_unpin(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Result<Self::Substream, Self::Error>>
where
Self: Unpin,
{
Pin::new(self).poll_outbound(cx)
}
/// Convenience function for calling [`StreamMuxer::poll_address_change`] for [`StreamMuxer`]s that are `Unpin`.
fn poll_address_change_unpin(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Result<Multiaddr, Self::Error>>
where
Self: Unpin,
{
Pin::new(self).poll_address_change(cx)
}
/// Convenience function for calling [`StreamMuxer::poll_close`] for [`StreamMuxer`]s that are `Unpin`.
fn poll_close_unpin(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>
where
Self: Unpin,
{
Pin::new(self).poll_close(cx)
}
/// Returns a future that resolves to the next inbound `Substream` opened by the remote.
fn next_inbound(&mut self) -> NextInbound<'_, Self> {
NextInbound(self)
}
/// Returns a future that opens a new outbound `Substream` with the remote.
fn next_outbound(&mut self) -> NextOutbound<'_, Self> {
NextOutbound(self)
}
/// Returns a future for closing this [`StreamMuxer`].
fn close(self) -> Close<Self> {
Close(self)
}
}
impl<S> StreamMuxerExt for S where S: StreamMuxer {}
pub struct NextInbound<'a, S>(&'a mut S);
pub struct NextOutbound<'a, S>(&'a mut S);
pub struct Close<S>(S);
impl<'a, S> Future for NextInbound<'a, S>
where
S: StreamMuxer + Unpin,
{
type Output = Result<S::Substream, S::Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.0.poll_inbound_unpin(cx)
}
}
impl<'a, S> Future for NextOutbound<'a, S>
where
S: StreamMuxer + Unpin,
{
type Output = Result<S::Substream, S::Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.0.poll_outbound_unpin(cx)
}
}
impl<S> Future for Close<S>
where
S: StreamMuxer + Unpin,
{
type Output = Result<(), S::Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.0.poll_close_unpin(cx)
}
} }

View File

@ -1,6 +1,7 @@
use crate::StreamMuxer; use crate::StreamMuxer;
use futures::{AsyncRead, AsyncWrite}; use futures::{AsyncRead, AsyncWrite};
use multiaddr::Multiaddr; use multiaddr::Multiaddr;
use pin_project::pin_project;
use std::error::Error; use std::error::Error;
use std::fmt; use std::fmt;
use std::io; use std::io;
@ -10,7 +11,7 @@ use std::task::{Context, Poll};
/// Abstract `StreamMuxer`. /// Abstract `StreamMuxer`.
pub struct StreamMuxerBox { pub struct StreamMuxerBox {
inner: Box<dyn StreamMuxer<Substream = SubstreamBox, Error = io::Error> + Send>, inner: Pin<Box<dyn StreamMuxer<Substream = SubstreamBox, Error = io::Error> + Send>>,
} }
/// Abstract type for asynchronous reading and writing. /// Abstract type for asynchronous reading and writing.
@ -19,10 +20,12 @@ pub struct StreamMuxerBox {
/// and `AsyncWrite` capabilities. /// and `AsyncWrite` capabilities.
pub struct SubstreamBox(Pin<Box<dyn AsyncReadWrite + Send>>); pub struct SubstreamBox(Pin<Box<dyn AsyncReadWrite + Send>>);
#[pin_project]
struct Wrap<T> struct Wrap<T>
where where
T: StreamMuxer, T: StreamMuxer,
{ {
#[pin]
inner: T, inner: T,
} }
@ -36,26 +39,40 @@ where
type Error = io::Error; type Error = io::Error;
#[inline] #[inline]
fn poll_close(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_close(cx).map_err(into_io_error) self.project().inner.poll_close(cx).map_err(into_io_error)
} }
fn poll_inbound(&self, cx: &mut Context<'_>) -> Poll<Result<Self::Substream, Self::Error>> { fn poll_inbound(
self.inner self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Self::Substream, Self::Error>> {
self.project()
.inner
.poll_inbound(cx) .poll_inbound(cx)
.map_ok(SubstreamBox::new) .map_ok(SubstreamBox::new)
.map_err(into_io_error) .map_err(into_io_error)
} }
fn poll_outbound(&self, cx: &mut Context<'_>) -> Poll<Result<Self::Substream, Self::Error>> { fn poll_outbound(
self.inner self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Self::Substream, Self::Error>> {
self.project()
.inner
.poll_outbound(cx) .poll_outbound(cx)
.map_ok(SubstreamBox::new) .map_ok(SubstreamBox::new)
.map_err(into_io_error) .map_err(into_io_error)
} }
fn poll_address_change(&self, cx: &mut Context<'_>) -> Poll<Result<Multiaddr, Self::Error>> { fn poll_address_change(
self.inner.poll_address_change(cx).map_err(into_io_error) self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Multiaddr, Self::Error>> {
self.project()
.inner
.poll_address_change(cx)
.map_err(into_io_error)
} }
} }
@ -77,9 +94,15 @@ impl StreamMuxerBox {
let wrap = Wrap { inner: muxer }; let wrap = Wrap { inner: muxer };
StreamMuxerBox { StreamMuxerBox {
inner: Box::new(wrap), inner: Box::pin(wrap),
} }
} }
fn project(
self: Pin<&mut Self>,
) -> Pin<&mut (dyn StreamMuxer<Substream = SubstreamBox, Error = io::Error> + Send)> {
self.get_mut().inner.as_mut()
}
} }
impl StreamMuxer for StreamMuxerBox { impl StreamMuxer for StreamMuxerBox {
@ -87,20 +110,29 @@ impl StreamMuxer for StreamMuxerBox {
type Error = io::Error; type Error = io::Error;
#[inline] #[inline]
fn poll_close(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_close(cx) self.project().poll_close(cx)
} }
fn poll_inbound(&self, cx: &mut Context<'_>) -> Poll<Result<Self::Substream, Self::Error>> { fn poll_inbound(
self.inner.poll_inbound(cx) self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Self::Substream, Self::Error>> {
self.project().poll_inbound(cx)
} }
fn poll_outbound(&self, cx: &mut Context<'_>) -> Poll<Result<Self::Substream, Self::Error>> { fn poll_outbound(
self.inner.poll_outbound(cx) self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Self::Substream, Self::Error>> {
self.project().poll_outbound(cx)
} }
fn poll_address_change(&self, cx: &mut Context<'_>) -> Poll<Result<Multiaddr, Self::Error>> { fn poll_address_change(
self.inner.poll_address_change(cx) self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Multiaddr, Self::Error>> {
self.project().poll_address_change(cx)
} }
} }

View File

@ -23,6 +23,7 @@ use crate::{connection::Endpoint, muxing::StreamMuxer};
use futures::prelude::*; use futures::prelude::*;
use multiaddr::Multiaddr; use multiaddr::Multiaddr;
use std::cell::Cell; use std::cell::Cell;
use std::pin::Pin;
use std::{io, task::Context, task::Poll}; use std::{io, task::Context, task::Poll};
/// Implementation of `StreamMuxer` that allows only one substream on top of a connection, /// Implementation of `StreamMuxer` that allows only one substream on top of a connection,
@ -57,31 +58,44 @@ where
type Substream = TSocket; type Substream = TSocket;
type Error = io::Error; type Error = io::Error;
fn poll_inbound(&self, _: &mut Context<'_>) -> Poll<Result<Self::Substream, Self::Error>> { fn poll_inbound(
match self.endpoint { self: Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<Result<Self::Substream, Self::Error>> {
let this = self.get_mut();
match this.endpoint {
Endpoint::Dialer => Poll::Pending, Endpoint::Dialer => Poll::Pending,
Endpoint::Listener => match self.inner.replace(None) { Endpoint::Listener => match this.inner.replace(None) {
None => Poll::Pending, None => Poll::Pending,
Some(stream) => Poll::Ready(Ok(stream)), Some(stream) => Poll::Ready(Ok(stream)),
}, },
} }
} }
fn poll_outbound(&self, _: &mut Context<'_>) -> Poll<Result<Self::Substream, Self::Error>> { fn poll_outbound(
match self.endpoint { self: Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<Result<Self::Substream, Self::Error>> {
let this = self.get_mut();
match this.endpoint {
Endpoint::Listener => Poll::Pending, Endpoint::Listener => Poll::Pending,
Endpoint::Dialer => match self.inner.replace(None) { Endpoint::Dialer => match this.inner.replace(None) {
None => Poll::Pending, None => Poll::Pending,
Some(stream) => Poll::Ready(Ok(stream)), Some(stream) => Poll::Ready(Ok(stream)),
}, },
} }
} }
fn poll_address_change(&self, _: &mut Context<'_>) -> Poll<Result<Multiaddr, Self::Error>> { fn poll_address_change(
self: Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<Result<Multiaddr, Self::Error>> {
Poll::Pending Poll::Pending
} }
fn poll_close(&self, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
} }

View File

@ -26,9 +26,9 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughpu
use futures::future::poll_fn; use futures::future::poll_fn;
use futures::prelude::*; use futures::prelude::*;
use futures::{channel::oneshot, future::join}; use futures::{channel::oneshot, future::join};
use libp2p_core::muxing::StreamMuxerExt;
use libp2p_core::{ use libp2p_core::{
identity, multiaddr::multiaddr, muxing, transport, upgrade, Multiaddr, PeerId, StreamMuxer, identity, multiaddr::multiaddr, muxing, transport, upgrade, Multiaddr, PeerId, Transport,
Transport,
}; };
use libp2p_mplex as mplex; use libp2p_mplex as mplex;
use libp2p_plaintext::PlainText2Config; use libp2p_plaintext::PlainText2Config;
@ -113,10 +113,8 @@ fn run(
addr_sender.take().unwrap().send(listen_addr).unwrap(); addr_sender.take().unwrap().send(listen_addr).unwrap();
} }
transport::TransportEvent::Incoming { upgrade, .. } => { transport::TransportEvent::Incoming { upgrade, .. } => {
let (_peer, conn) = upgrade.await.unwrap(); let (_peer, mut conn) = upgrade.await.unwrap();
let mut s = poll_fn(|cx| conn.poll_inbound(cx)) let mut s = conn.next_inbound().await.expect("unexpected error");
.await
.expect("unexpected error");
let mut buf = vec![0u8; payload_len]; let mut buf = vec![0u8; payload_len];
let mut off = 0; let mut off = 0;
@ -140,8 +138,8 @@ fn run(
// Spawn and block on the sender, i.e. until all data is sent. // Spawn and block on the sender, i.e. until all data is sent.
let sender = async move { let sender = async move {
let addr = addr_receiver.await.unwrap(); let addr = addr_receiver.await.unwrap();
let (_peer, conn) = sender_trans.dial(addr).unwrap().await.unwrap(); let (_peer, mut conn) = sender_trans.dial(addr).unwrap().await.unwrap();
let mut stream = poll_fn(|cx| conn.poll_outbound(cx)).await.unwrap(); let mut stream = conn.next_outbound().await.unwrap();
let mut off = 0; let mut off = 0;
loop { loop {
let n = poll_fn(|cx| Pin::new(&mut stream).poll_write(cx, &payload[off..])) let n = poll_fn(|cx| Pin::new(&mut stream).poll_write(cx, &payload[off..]))

View File

@ -85,25 +85,34 @@ where
type Substream = Substream<C>; type Substream = Substream<C>;
type Error = io::Error; type Error = io::Error;
fn poll_inbound(&self, cx: &mut Context<'_>) -> Poll<Result<Self::Substream, Self::Error>> { fn poll_inbound(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Self::Substream, Self::Error>> {
self.io self.io
.lock() .lock()
.poll_next_stream(cx) .poll_next_stream(cx)
.map_ok(|stream_id| Substream::new(stream_id, self.io.clone())) .map_ok(|stream_id| Substream::new(stream_id, self.io.clone()))
} }
fn poll_outbound(&self, cx: &mut Context<'_>) -> Poll<Result<Self::Substream, Self::Error>> { fn poll_outbound(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Self::Substream, Self::Error>> {
self.io self.io
.lock() .lock()
.poll_open_stream(cx) .poll_open_stream(cx)
.map_ok(|stream_id| Substream::new(stream_id, self.io.clone())) .map_ok(|stream_id| Substream::new(stream_id, self.io.clone()))
} }
fn poll_address_change(&self, _: &mut Context<'_>) -> Poll<Result<Multiaddr, Self::Error>> { fn poll_address_change(
self: Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<Result<Multiaddr, Self::Error>> {
Poll::Pending Poll::Pending
} }
fn poll_close(&self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.io.lock().poll_close(cx) self.io.lock().poll_close(cx)
} }
} }

View File

@ -18,9 +18,9 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE. // DEALINGS IN THE SOFTWARE.
use futures::future::poll_fn;
use futures::{channel::oneshot, prelude::*}; use futures::{channel::oneshot, prelude::*};
use libp2p_core::{upgrade, StreamMuxer, Transport}; use libp2p_core::muxing::StreamMuxerExt;
use libp2p_core::{upgrade, Transport};
use libp2p_tcp::TcpTransport; use libp2p_tcp::TcpTransport;
#[test] #[test]
@ -49,7 +49,7 @@ fn async_write() {
tx.send(addr).unwrap(); tx.send(addr).unwrap();
let client = transport let mut client = transport
.next() .next()
.await .await
.expect("some event") .expect("some event")
@ -59,7 +59,7 @@ fn async_write() {
.await .await
.unwrap(); .unwrap();
let mut outbound = poll_fn(|cx| client.poll_outbound(cx)).await.unwrap(); let mut outbound = client.next_outbound().await.unwrap();
let mut buf = Vec::new(); let mut buf = Vec::new();
outbound.read_to_end(&mut buf).await.unwrap(); outbound.read_to_end(&mut buf).await.unwrap();
@ -71,8 +71,9 @@ fn async_write() {
let mut transport = TcpTransport::default() let mut transport = TcpTransport::default()
.and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)); .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1));
let client = transport.dial(rx.await.unwrap()).unwrap().await.unwrap(); let mut client = transport.dial(rx.await.unwrap()).unwrap().await.unwrap();
let mut inbound = poll_fn(|cx| client.poll_inbound(cx)).await.unwrap();
let mut inbound = client.next_inbound().await.unwrap();
inbound.write_all(b"hello world").await.unwrap(); inbound.write_all(b"hello world").await.unwrap();
// The test consists in making sure that this flushes the substream. // The test consists in making sure that this flushes the substream.

View File

@ -18,9 +18,9 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE. // DEALINGS IN THE SOFTWARE.
use futures::future::poll_fn;
use futures::{channel::oneshot, prelude::*}; use futures::{channel::oneshot, prelude::*};
use libp2p_core::{upgrade, StreamMuxer, Transport}; use libp2p_core::muxing::StreamMuxerExt;
use libp2p_core::{upgrade, Transport};
use libp2p_tcp::TcpTransport; use libp2p_tcp::TcpTransport;
#[test] #[test]
@ -49,7 +49,7 @@ fn client_to_server_outbound() {
tx.send(addr).unwrap(); tx.send(addr).unwrap();
let client = transport let mut client = transport
.next() .next()
.await .await
.expect("some event") .expect("some event")
@ -59,7 +59,7 @@ fn client_to_server_outbound() {
.await .await
.unwrap(); .unwrap();
let mut outbound = poll_fn(|cx| client.poll_outbound(cx)).await.unwrap(); let mut outbound = client.next_outbound().await.unwrap();
let mut buf = Vec::new(); let mut buf = Vec::new();
outbound.read_to_end(&mut buf).await.unwrap(); outbound.read_to_end(&mut buf).await.unwrap();
@ -72,8 +72,8 @@ fn client_to_server_outbound() {
.and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)) .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1))
.boxed(); .boxed();
let client = transport.dial(rx.await.unwrap()).unwrap().await.unwrap(); let mut client = transport.dial(rx.await.unwrap()).unwrap().await.unwrap();
let mut inbound = poll_fn(|cx| client.poll_inbound(cx)).await.unwrap(); let mut inbound = client.next_inbound().await.unwrap();
inbound.write_all(b"hello world").await.unwrap(); inbound.write_all(b"hello world").await.unwrap();
inbound.close().await.unwrap(); inbound.close().await.unwrap();
@ -107,7 +107,7 @@ fn client_to_server_inbound() {
tx.send(addr).unwrap(); tx.send(addr).unwrap();
let client = transport let mut client = transport
.next() .next()
.await .await
.expect("some event") .expect("some event")
@ -117,7 +117,7 @@ fn client_to_server_inbound() {
.await .await
.unwrap(); .unwrap();
let mut inbound = poll_fn(|cx| client.poll_inbound(cx)).await.unwrap(); let mut inbound = client.next_inbound().await.unwrap();
let mut buf = Vec::new(); let mut buf = Vec::new();
inbound.read_to_end(&mut buf).await.unwrap(); inbound.read_to_end(&mut buf).await.unwrap();
@ -130,9 +130,9 @@ fn client_to_server_inbound() {
.and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)) .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1))
.boxed(); .boxed();
let client = transport.dial(rx.await.unwrap()).unwrap().await.unwrap(); let mut client = transport.dial(rx.await.unwrap()).unwrap().await.unwrap();
let mut outbound = poll_fn(|cx| client.poll_outbound(cx)).await.unwrap(); let mut outbound = client.next_outbound().await.unwrap();
outbound.write_all(b"hello world").await.unwrap(); outbound.write_all(b"hello world").await.unwrap();
outbound.close().await.unwrap(); outbound.close().await.unwrap();
@ -164,7 +164,7 @@ fn protocol_not_match() {
tx.send(addr).unwrap(); tx.send(addr).unwrap();
let client = transport let mut client = transport
.next() .next()
.await .await
.expect("some event") .expect("some event")
@ -174,7 +174,7 @@ fn protocol_not_match() {
.await .await
.unwrap(); .unwrap();
let mut outbound = poll_fn(|cx| client.poll_outbound(cx)).await.unwrap(); let mut outbound = client.next_outbound().await.unwrap();
let mut buf = Vec::new(); let mut buf = Vec::new();
outbound.read_to_end(&mut buf).await.unwrap(); outbound.read_to_end(&mut buf).await.unwrap();

View File

@ -29,7 +29,6 @@ use futures::{
use libp2p_core::muxing::StreamMuxer; use libp2p_core::muxing::StreamMuxer;
use libp2p_core::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo}; use libp2p_core::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo};
use libp2p_core::Multiaddr; use libp2p_core::Multiaddr;
use parking_lot::Mutex;
use std::{ use std::{
fmt, io, iter, mem, fmt, io, iter, mem,
pin::Pin, pin::Pin,
@ -39,7 +38,12 @@ use thiserror::Error;
use yamux::ConnectionError; use yamux::ConnectionError;
/// A Yamux connection. /// A Yamux connection.
pub struct Yamux<S>(Mutex<Inner<S>>); pub struct Yamux<S> {
/// The [`futures::stream::Stream`] of incoming substreams.
incoming: S,
/// Handle to control the connection.
control: yamux::Control,
}
impl<S> fmt::Debug for Yamux<S> { impl<S> fmt::Debug for Yamux<S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
@ -47,13 +51,6 @@ impl<S> fmt::Debug for Yamux<S> {
} }
} }
struct Inner<S> {
/// The [`futures::stream::Stream`] of incoming substreams.
incoming: S,
/// Handle to control the connection.
control: yamux::Control,
}
/// A token to poll for an outbound substream. /// A token to poll for an outbound substream.
#[derive(Debug)] #[derive(Debug)]
pub struct OpenSubstreamToken(()); pub struct OpenSubstreamToken(());
@ -66,14 +63,14 @@ where
fn new(io: C, cfg: yamux::Config, mode: yamux::Mode) -> Self { fn new(io: C, cfg: yamux::Config, mode: yamux::Mode) -> Self {
let conn = yamux::Connection::new(io, cfg, mode); let conn = yamux::Connection::new(io, cfg, mode);
let ctrl = conn.control(); let ctrl = conn.control();
let inner = Inner {
Yamux {
incoming: Incoming { incoming: Incoming {
stream: yamux::into_stream(conn).err_into().boxed(), stream: yamux::into_stream(conn).err_into().boxed(),
_marker: std::marker::PhantomData, _marker: std::marker::PhantomData,
}, },
control: ctrl, control: ctrl,
}; }
Yamux(Mutex::new(inner))
} }
} }
@ -85,14 +82,14 @@ where
fn local(io: C, cfg: yamux::Config, mode: yamux::Mode) -> Self { fn local(io: C, cfg: yamux::Config, mode: yamux::Mode) -> Self {
let conn = yamux::Connection::new(io, cfg, mode); let conn = yamux::Connection::new(io, cfg, mode);
let ctrl = conn.control(); let ctrl = conn.control();
let inner = Inner {
Yamux {
incoming: LocalIncoming { incoming: LocalIncoming {
stream: yamux::into_stream(conn).err_into().boxed_local(), stream: yamux::into_stream(conn).err_into().boxed_local(),
_marker: std::marker::PhantomData, _marker: std::marker::PhantomData,
}, },
control: ctrl, control: ctrl,
}; }
Yamux(Mutex::new(inner))
} }
} }
@ -105,41 +102,44 @@ where
type Substream = yamux::Stream; type Substream = yamux::Stream;
type Error = YamuxError; type Error = YamuxError;
fn poll_inbound(&self, cx: &mut Context<'_>) -> Poll<Result<Self::Substream, Self::Error>> { fn poll_inbound(
self.0 mut self: Pin<&mut Self>,
.lock() cx: &mut Context<'_>,
.incoming ) -> Poll<Result<Self::Substream, Self::Error>> {
.poll_next_unpin(cx) self.incoming.poll_next_unpin(cx).map(|maybe_stream| {
.map(|maybe_stream| { let stream = maybe_stream
let stream = maybe_stream .transpose()?
.transpose()? .ok_or(YamuxError(ConnectionError::Closed))?;
.ok_or(YamuxError(ConnectionError::Closed))?;
Ok(stream) Ok(stream)
}) })
} }
fn poll_outbound(&self, cx: &mut Context<'_>) -> Poll<Result<Self::Substream, Self::Error>> { fn poll_outbound(
Pin::new(&mut self.0.lock().control) mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Self::Substream, Self::Error>> {
Pin::new(&mut self.control)
.poll_open_stream(cx) .poll_open_stream(cx)
.map_err(YamuxError) .map_err(YamuxError)
} }
fn poll_address_change(&self, _: &mut Context<'_>) -> Poll<Result<Multiaddr, Self::Error>> { fn poll_address_change(
self: Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<Result<Multiaddr, Self::Error>> {
Poll::Pending Poll::Pending
} }
fn poll_close(&self, c: &mut Context<'_>) -> Poll<YamuxResult<()>> { fn poll_close(mut self: Pin<&mut Self>, c: &mut Context<'_>) -> Poll<YamuxResult<()>> {
let mut inner = self.0.lock(); if let Poll::Ready(()) = Pin::new(&mut self.control)
if let Poll::Ready(()) = Pin::new(&mut inner.control)
.poll_close(c) .poll_close(c)
.map_err(YamuxError)? .map_err(YamuxError)?
{ {
return Poll::Ready(Ok(())); return Poll::Ready(Ok(()));
} }
while let Poll::Ready(maybe_inbound_stream) = inner.incoming.poll_next_unpin(c)? { while let Poll::Ready(maybe_inbound_stream) = self.incoming.poll_next_unpin(c)? {
match maybe_inbound_stream { match maybe_inbound_stream {
Some(inbound_stream) => mem::drop(inbound_stream), Some(inbound_stream) => mem::drop(inbound_stream),
None => return Poll::Ready(Ok(())), None => return Poll::Ready(Ok(())),

View File

@ -32,13 +32,12 @@ pub use pool::{EstablishedConnection, PendingConnection};
use crate::handler::ConnectionHandler; use crate::handler::ConnectionHandler;
use crate::IntoConnectionHandler; use crate::IntoConnectionHandler;
use futures::future::poll_fn;
use handler_wrapper::HandlerWrapper; use handler_wrapper::HandlerWrapper;
use libp2p_core::connection::ConnectedPoint; use libp2p_core::connection::ConnectedPoint;
use libp2p_core::multiaddr::Multiaddr; use libp2p_core::multiaddr::Multiaddr;
use libp2p_core::muxing::StreamMuxerBox; use libp2p_core::muxing::{StreamMuxerBox, StreamMuxerExt};
use libp2p_core::upgrade;
use libp2p_core::PeerId; use libp2p_core::PeerId;
use libp2p_core::{upgrade, StreamMuxer};
use std::collections::VecDeque; use std::collections::VecDeque;
use std::future::Future; use std::future::Future;
use std::{error::Error, fmt, io, pin::Pin, task::Context, task::Poll}; use std::{error::Error, fmt, io, pin::Pin, task::Context, task::Poll};
@ -132,10 +131,7 @@ where
/// Begins an orderly shutdown of the connection, returning the connection /// Begins an orderly shutdown of the connection, returning the connection
/// handler and a `Future` that resolves when connection shutdown is complete. /// handler and a `Future` that resolves when connection shutdown is complete.
pub fn close(self) -> (THandler, impl Future<Output = io::Result<()>>) { pub fn close(self) -> (THandler, impl Future<Output = io::Result<()>>) {
( (self.handler.into_connection_handler(), self.muxing.close())
self.handler.into_connection_handler(),
poll_fn(move |cx| self.muxing.poll_close(cx)),
)
} }
/// Polls the handler and the substream, forwarding events from the former to the latter and /// Polls the handler and the substream, forwarding events from the former to the latter and
@ -158,7 +154,7 @@ where
} }
if !self.open_info.is_empty() { if !self.open_info.is_empty() {
if let Poll::Ready(substream) = self.muxing.poll_outbound(cx)? { if let Poll::Ready(substream) = self.muxing.poll_outbound_unpin(cx)? {
let user_data = self let user_data = self
.open_info .open_info
.pop_front() .pop_front()
@ -169,13 +165,13 @@ where
} }
} }
if let Poll::Ready(substream) = self.muxing.poll_inbound(cx)? { if let Poll::Ready(substream) = self.muxing.poll_inbound_unpin(cx)? {
self.handler self.handler
.inject_substream(substream, SubstreamEndpoint::Listener); .inject_substream(substream, SubstreamEndpoint::Listener);
continue; // Go back to the top, handler can potentially make progress again. continue; // Go back to the top, handler can potentially make progress again.
} }
if let Poll::Ready(address) = self.muxing.poll_address_change(cx)? { if let Poll::Ready(address) = self.muxing.poll_address_change_unpin(cx)? {
self.handler.inject_address_change(&address); self.handler.inject_address_change(&address);
return Poll::Ready(Ok(Event::AddressChange(address))); return Poll::Ready(Ok(Event::AddressChange(address)));
} }

View File

@ -38,7 +38,7 @@ use futures::{
stream::FuturesUnordered, stream::FuturesUnordered,
}; };
use libp2p_core::connection::{ConnectionId, Endpoint, PendingPoint}; use libp2p_core::connection::{ConnectionId, Endpoint, PendingPoint};
use libp2p_core::muxing::{StreamMuxer, StreamMuxerBox}; use libp2p_core::muxing::{StreamMuxerBox, StreamMuxerExt};
use std::{ use std::{
collections::{hash_map, HashMap}, collections::{hash_map, HashMap},
convert::TryFrom as _, convert::TryFrom as _,
@ -604,7 +604,7 @@ where
match event { match event {
task::PendingConnectionEvent::ConnectionEstablished { task::PendingConnectionEvent::ConnectionEstablished {
id, id,
output: (obtained_peer_id, muxer), output: (obtained_peer_id, mut muxer),
outgoing, outgoing,
} => { } => {
let PendingConnectionInfo { let PendingConnectionInfo {
@ -692,7 +692,7 @@ where
if let Err(error) = error { if let Err(error) = error {
self.spawn( self.spawn(
poll_fn(move |cx| { poll_fn(move |cx| {
if let Err(e) = ready!(muxer.poll_close(cx)) { if let Err(e) = ready!(muxer.poll_close_unpin(cx)) {
log::debug!( log::debug!(
"Failed to close connection {:?} to peer {}: {:?}", "Failed to close connection {:?} to peer {}: {:?}",
id, id,