diff --git a/Cargo.toml b/Cargo.toml index d2f361cb..cf4ca0d5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -126,6 +126,7 @@ members = [ "misc/metrics", "misc/multistream-select", "misc/keygen", + "misc/prost-codec", "muxers/mplex", "muxers/yamux", "protocols/dcutr", diff --git a/misc/prost-codec/Cargo.toml b/misc/prost-codec/Cargo.toml new file mode 100644 index 00000000..02969651 --- /dev/null +++ b/misc/prost-codec/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "prost-codec" +edition = "2021" +rust-version = "1.56.1" +description = "Asynchronous de-/encoding of Protobuf structs using asynchronous-codec, unsigned-varint and prost." +version = "0.1.0" +authors = ["Max Inden "] +license = "MIT" +repository = "https://github.com/libp2p/rust-libp2p" +keywords = ["networking"] +categories = ["asynchronous"] + +[dependencies] +asynchronous-codec = { version = "0.6" } +bytes = { version = "1" } +prost = "0.10" +thiserror = "1.0" +unsigned-varint = { version = "0.7", features = ["asynchronous_codec"] } + +[dev-dependencies] +prost-build = "0.10" diff --git a/misc/prost-codec/src/lib.rs b/misc/prost-codec/src/lib.rs new file mode 100644 index 00000000..32b8c9b9 --- /dev/null +++ b/misc/prost-codec/src/lib.rs @@ -0,0 +1,81 @@ +use asynchronous_codec::{Decoder, Encoder}; +use bytes::BytesMut; +use prost::Message; +use std::io::Cursor; +use std::marker::PhantomData; +use thiserror::Error; +use unsigned_varint::codec::UviBytes; + +/// [`Codec`] implements [`Encoder`] and [`Decoder`], uses [`unsigned_varint`] +/// to prefix messages with their length and uses [`prost`] and a provided +/// `struct` implementing [`Message`] to do the encoding. +pub struct Codec { + uvi: UviBytes, + phantom: PhantomData<(In, Out)>, +} + +impl Codec { + /// Create new [`Codec`]. + /// + /// Parameter `max_message_len_bytes` determines the maximum length of the + /// Protobuf message. The limit does not include the bytes needed for the + /// [`unsigned_varint`]. + pub fn new(max_message_len_bytes: usize) -> Self { + let mut uvi = UviBytes::default(); + uvi.set_max_len(max_message_len_bytes); + Self { + uvi, + phantom: PhantomData::default(), + } + } +} + +impl Encoder for Codec { + type Item = In; + type Error = Error; + + fn encode( + &mut self, + item: Self::Item, + dst: &mut asynchronous_codec::BytesMut, + ) -> Result<(), Self::Error> { + let mut encoded_msg = BytesMut::new(); + item.encode(&mut encoded_msg) + .expect("BytesMut to have sufficient capacity."); + self.uvi + .encode(encoded_msg.freeze(), dst) + .map_err(|e| e.into()) + } +} + +impl Decoder for Codec { + type Item = Out; + type Error = Error; + + fn decode( + &mut self, + src: &mut asynchronous_codec::BytesMut, + ) -> Result, Self::Error> { + Ok(self + .uvi + .decode(src)? + .map(|msg| Message::decode(Cursor::new(msg))) + .transpose()?) + } +} + +#[derive(Debug, Error)] +pub enum Error { + #[error("Failed to decode response: {0}.")] + Decode( + #[from] + #[source] + prost::DecodeError, + ), + #[error("Io error {0}")] + Io( + #[from] + #[source] + std::io::Error, + ), +} diff --git a/protocols/dcutr/Cargo.toml b/protocols/dcutr/Cargo.toml index 6b8a84e3..2dfc08e4 100644 --- a/protocols/dcutr/Cargo.toml +++ b/protocols/dcutr/Cargo.toml @@ -20,9 +20,9 @@ instant = "0.1.11" libp2p-core = { version = "0.33.0", path = "../../core" } libp2p-swarm = { version = "0.36.0", path = "../../swarm" } log = "0.4" +prost-codec = { version = "0.1", path = "../../misc/prost-codec" } prost = "0.10" thiserror = "1.0" -unsigned-varint = { version = "0.7", features = ["asynchronous_codec"] } void = "1" [build-dependencies] diff --git a/protocols/dcutr/src/protocol.rs b/protocols/dcutr/src/protocol.rs index a69a27f7..d2b8b39a 100644 --- a/protocols/dcutr/src/protocol.rs +++ b/protocols/dcutr/src/protocol.rs @@ -18,8 +18,9 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -mod codec; pub mod inbound; pub mod outbound; const PROTOCOL_NAME: &[u8; 13] = b"/libp2p/dcutr"; + +const MAX_MESSAGE_SIZE_BYTES: usize = 4096; diff --git a/protocols/dcutr/src/protocol/codec.rs b/protocols/dcutr/src/protocol/codec.rs deleted file mode 100644 index 9706d756..00000000 --- a/protocols/dcutr/src/protocol/codec.rs +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright 2022 Protocol Labs. -// -// Permission is hereby granted, free of charge, to any person obtaining a -// copy of this software and associated documentation files (the "Software"), -// to deal in the Software without restriction, including without limitation -// the rights to use, copy, modify, merge, publish, distribute, sublicense, -// and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. - -use crate::message_proto; -use bytes::BytesMut; -use prost::Message; -use std::io::Cursor; -use thiserror::Error; -use unsigned_varint::codec::UviBytes; - -const MAX_MESSAGE_SIZE_BYTES: usize = 4096; - -pub struct Codec(UviBytes); - -impl Codec { - pub fn new() -> Self { - let mut codec = UviBytes::default(); - codec.set_max_len(MAX_MESSAGE_SIZE_BYTES); - Self(codec) - } -} - -impl asynchronous_codec::Encoder for Codec { - type Item = message_proto::HolePunch; - type Error = Error; - - fn encode( - &mut self, - item: Self::Item, - dst: &mut asynchronous_codec::BytesMut, - ) -> Result<(), Self::Error> { - let mut encoded_msg = BytesMut::new(); - item.encode(&mut encoded_msg) - .expect("BytesMut to have sufficient capacity."); - self.0 - .encode(encoded_msg.freeze(), dst) - .map_err(|e| e.into()) - } -} - -impl asynchronous_codec::Decoder for Codec { - type Item = message_proto::HolePunch; - type Error = Error; - - fn decode( - &mut self, - src: &mut asynchronous_codec::BytesMut, - ) -> Result, Self::Error> { - Ok(self - .0 - .decode(src)? - .map(|msg| message_proto::HolePunch::decode(Cursor::new(msg))) - .transpose()?) - } -} - -#[derive(Debug, Error)] -pub enum Error { - #[error("Failed to decode response: {0}.")] - Decode( - #[from] - #[source] - prost::DecodeError, - ), - #[error("Io error {0}")] - Io( - #[from] - #[source] - std::io::Error, - ), -} diff --git a/protocols/dcutr/src/protocol/inbound.rs b/protocols/dcutr/src/protocol/inbound.rs index c05e1072..cfb28ca5 100644 --- a/protocols/dcutr/src/protocol/inbound.rs +++ b/protocols/dcutr/src/protocol/inbound.rs @@ -44,17 +44,14 @@ impl upgrade::InboundUpgrade for Upgrade { type Future = BoxFuture<'static, Result>; fn upgrade_inbound(self, substream: NegotiatedSubstream, _: Self::Info) -> Self::Future { - let mut substream = Framed::new(substream, super::codec::Codec::new()); + let mut substream = Framed::new( + substream, + prost_codec::Codec::new(super::MAX_MESSAGE_SIZE_BYTES), + ); async move { let HolePunch { r#type, obs_addrs } = - substream - .next() - .await - .ok_or(super::codec::Error::Io(std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, - "", - )))??; + substream.next().await.ok_or(UpgradeError::StreamClosed)??; let obs_addrs = if obs_addrs.is_empty() { return Err(UpgradeError::NoAddresses); @@ -88,7 +85,7 @@ impl upgrade::InboundUpgrade for Upgrade { } pub struct PendingConnect { - substream: Framed, + substream: Framed>, remote_obs_addrs: Vec, } @@ -103,14 +100,11 @@ impl PendingConnect { }; self.substream.send(msg).await?; - let HolePunch { r#type, .. } = - self.substream - .next() - .await - .ok_or(super::codec::Error::Io(std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, - "", - )))??; + let HolePunch { r#type, .. } = self + .substream + .next() + .await + .ok_or(UpgradeError::StreamClosed)??; let r#type = hole_punch::Type::from_i32(r#type).ok_or(UpgradeError::ParseTypeField)?; match r#type { @@ -124,12 +118,14 @@ impl PendingConnect { #[derive(Debug, Error)] pub enum UpgradeError { - #[error("Failed to encode or decode: {0}")] + #[error("Failed to encode or decode")] Codec( #[from] #[source] - super::codec::Error, + prost_codec::Error, ), + #[error("Stream closed")] + StreamClosed, #[error("Expected at least one address in reservation.")] NoAddresses, #[error("Invalid addresses.")] diff --git a/protocols/dcutr/src/protocol/outbound.rs b/protocols/dcutr/src/protocol/outbound.rs index 332554c2..dc55aa41 100644 --- a/protocols/dcutr/src/protocol/outbound.rs +++ b/protocols/dcutr/src/protocol/outbound.rs @@ -54,7 +54,10 @@ impl upgrade::OutboundUpgrade for Upgrade { type Future = BoxFuture<'static, Result>; fn upgrade_outbound(self, substream: NegotiatedSubstream, _: Self::Info) -> Self::Future { - let mut substream = Framed::new(substream, super::codec::Codec::new()); + let mut substream = Framed::new( + substream, + prost_codec::Codec::new(super::MAX_MESSAGE_SIZE_BYTES), + ); let msg = HolePunch { r#type: hole_punch::Type::Connect.into(), @@ -67,13 +70,7 @@ impl upgrade::OutboundUpgrade for Upgrade { let sent_time = Instant::now(); let HolePunch { r#type, obs_addrs } = - substream - .next() - .await - .ok_or(super::codec::Error::Io(std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, - "", - )))??; + substream.next().await.ok_or(UpgradeError::StreamClosed)??; let rtt = sent_time.elapsed(); @@ -123,8 +120,10 @@ pub enum UpgradeError { Codec( #[from] #[source] - super::codec::Error, + prost_codec::Error, ), + #[error("Stream closed")] + StreamClosed, #[error("Expected 'status' field to be set.")] MissingStatusField, #[error("Expected 'reservation' field to be set.")] diff --git a/protocols/identify/CHANGELOG.md b/protocols/identify/CHANGELOG.md index 4a0356a7..3a13b4a9 100644 --- a/protocols/identify/CHANGELOG.md +++ b/protocols/identify/CHANGELOG.md @@ -4,6 +4,9 @@ - Update to `libp2p-swarm` `v0.36.0`. +- Expose explicits errors via `UpgradeError` instead of generic `io::Error`. See [PR 2630]. + +[PR 2630]: https://github.com/libp2p/rust-libp2p/pull/2630 # 0.35.0 - Update to `libp2p-swarm` `v0.35.0`. diff --git a/protocols/identify/Cargo.toml b/protocols/identify/Cargo.toml index 69945251..d16a35b9 100644 --- a/protocols/identify/Cargo.toml +++ b/protocols/identify/Cargo.toml @@ -11,14 +11,17 @@ keywords = ["peer-to-peer", "libp2p", "networking"] categories = ["network-programming", "asynchronous"] [dependencies] +asynchronous-codec = "0.6" futures = "0.3.1" futures-timer = "3.0.2" libp2p-core = { version = "0.33.0", path = "../../core", default-features = false } libp2p-swarm = { version = "0.36.0", path = "../../swarm" } log = "0.4.1" lru = "0.7.2" +prost-codec = { version = "0.1", path = "../../misc/prost-codec" } prost = "0.10" smallvec = "1.6.1" +thiserror = "1.0" [dev-dependencies] async-std = "1.6.2" diff --git a/protocols/identify/src/handler.rs b/protocols/identify/src/handler.rs index f22c0bd3..85f83f10 100644 --- a/protocols/identify/src/handler.rs +++ b/protocols/identify/src/handler.rs @@ -19,14 +19,13 @@ // DEALINGS IN THE SOFTWARE. use crate::protocol::{ - IdentifyInfo, IdentifyProtocol, IdentifyPushProtocol, InboundPush, OutboundPush, ReplySubstream, + IdentifyInfo, IdentifyProtocol, IdentifyPushProtocol, InboundPush, OutboundPush, + ReplySubstream, UpgradeError, }; use futures::prelude::*; use futures_timer::Delay; use libp2p_core::either::{EitherError, EitherOutput}; -use libp2p_core::upgrade::{ - EitherUpgrade, InboundUpgrade, OutboundUpgrade, SelectUpgrade, UpgradeError, -}; +use libp2p_core::upgrade::{EitherUpgrade, InboundUpgrade, OutboundUpgrade, SelectUpgrade}; use libp2p_swarm::{ ConnectionHandler, ConnectionHandlerEvent, ConnectionHandlerUpgrErr, KeepAlive, NegotiatedSubstream, SubstreamProtocol, @@ -70,7 +69,7 @@ pub enum IdentifyHandlerEvent { /// We received a request for identification. Identify(ReplySubstream), /// Failed to identify the remote. - IdentificationError(ConnectionHandlerUpgrErr), + IdentificationError(ConnectionHandlerUpgrErr), } /// Identifying information of the local node that is pushed to a remote. @@ -155,6 +154,8 @@ impl ConnectionHandler for IdentifyHandler { >::Error, >, ) { + use libp2p_core::upgrade::UpgradeError; + let err = err.map_upgrade_err(|e| match e { UpgradeError::Select(e) => UpgradeError::Select(e), UpgradeError::Apply(EitherError::A(ioe)) => UpgradeError::Apply(ioe), diff --git a/protocols/identify/src/identify.rs b/protocols/identify/src/identify.rs index 20daddcc..5f5e4683 100644 --- a/protocols/identify/src/identify.rs +++ b/protocols/identify/src/identify.rs @@ -19,12 +19,11 @@ // DEALINGS IN THE SOFTWARE. use crate::handler::{IdentifyHandler, IdentifyHandlerEvent, IdentifyPush}; -use crate::protocol::{IdentifyInfo, ReplySubstream}; +use crate::protocol::{IdentifyInfo, ReplySubstream, UpgradeError}; use futures::prelude::*; use libp2p_core::{ connection::{ConnectionId, ListenerId}, multiaddr::Protocol, - upgrade::UpgradeError, ConnectedPoint, Multiaddr, PeerId, PublicKey, }; use libp2p_swarm::{ @@ -35,7 +34,6 @@ use libp2p_swarm::{ use lru::LruCache; use std::{ collections::{HashMap, HashSet, VecDeque}, - io, iter::FromIterator, pin::Pin, task::Context, @@ -75,7 +73,7 @@ enum Reply { /// The reply is being sent. Sending { peer: PeerId, - io: Pin> + Send>>, + io: Pin> + Send>>, }, } @@ -429,9 +427,9 @@ impl NetworkBehaviour for Identify { Poll::Ready(Err(err)) => { let event = IdentifyEvent::Error { peer_id: peer, - error: ConnectionHandlerUpgrErr::Upgrade(UpgradeError::Apply( - err, - )), + error: ConnectionHandlerUpgrErr::Upgrade( + libp2p_core::upgrade::UpgradeError::Apply(err), + ), }; return Poll::Ready(NetworkBehaviourAction::GenerateEvent(event)); } @@ -482,7 +480,7 @@ pub enum IdentifyEvent { /// The peer with whom the error originated. peer_id: PeerId, /// The error that occurred. - error: ConnectionHandlerUpgrErr, + error: ConnectionHandlerUpgrErr, }, } diff --git a/protocols/identify/src/lib.rs b/protocols/identify/src/lib.rs index db1c5536..f5de8f7a 100644 --- a/protocols/identify/src/lib.rs +++ b/protocols/identify/src/lib.rs @@ -45,7 +45,7 @@ //! [`IdentifyInfo`]: self::IdentifyInfo pub use self::identify::{Identify, IdentifyConfig, IdentifyEvent}; -pub use self::protocol::IdentifyInfo; +pub use self::protocol::{IdentifyInfo, UpgradeError}; mod handler; mod identify; diff --git a/protocols/identify/src/protocol.rs b/protocols/identify/src/protocol.rs index d54f5f4f..439412d5 100644 --- a/protocols/identify/src/protocol.rs +++ b/protocols/identify/src/protocol.rs @@ -19,15 +19,19 @@ // DEALINGS IN THE SOFTWARE. use crate::structs_proto; +use asynchronous_codec::{FramedRead, FramedWrite}; use futures::prelude::*; use libp2p_core::{ - upgrade::{self, InboundUpgrade, OutboundUpgrade, UpgradeInfo}, + identity, multiaddr, + upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo}, Multiaddr, PublicKey, }; -use log::{debug, trace}; -use prost::Message; +use log::trace; use std::convert::TryFrom; use std::{fmt, io, iter, pin::Pin}; +use thiserror::Error; + +const MAX_MESSAGE_SIZE_BYTES: usize = 4096; /// Substream upgrade protocol for `/ipfs/id/1.0.0`. #[derive(Debug, Clone)] @@ -89,8 +93,8 @@ where /// /// Consumes the substream, returning a future that resolves /// when the reply has been sent on the underlying connection. - pub async fn send(self, info: IdentifyInfo) -> io::Result<()> { - send(self.inner, info).await + pub async fn send(self, info: IdentifyInfo) -> Result<(), UpgradeError> { + send(self.inner, info).await.map_err(Into::into) } } @@ -105,8 +109,8 @@ impl UpgradeInfo for IdentifyProtocol { impl InboundUpgrade for IdentifyProtocol { type Output = ReplySubstream; - type Error = io::Error; - type Future = future::Ready>; + type Error = UpgradeError; + type Future = future::Ready>; fn upgrade_inbound(self, socket: C, _: Self::Info) -> Self::Future { future::ok(ReplySubstream { inner: socket }) @@ -118,7 +122,7 @@ where C: AsyncRead + AsyncWrite + Unpin + Send + 'static, { type Output = IdentifyInfo; - type Error = io::Error; + type Error = UpgradeError; type Future = Pin> + Send>>; fn upgrade_outbound(self, socket: C, _: Self::Info) -> Self::Future { @@ -140,7 +144,7 @@ where C: AsyncRead + AsyncWrite + Unpin + Send + 'static, { type Output = IdentifyInfo; - type Error = io::Error; + type Error = UpgradeError; type Future = Pin> + Send>>; fn upgrade_inbound(self, socket: C, _: Self::Info) -> Self::Future { @@ -153,7 +157,7 @@ where C: AsyncWrite + Unpin + Send + 'static, { type Output = (); - type Error = io::Error; + type Error = UpgradeError; type Future = Pin> + Send>>; fn upgrade_outbound(self, socket: C, _: Self::Info) -> Self::Future { @@ -161,7 +165,7 @@ where } } -async fn send(mut io: T, info: IdentifyInfo) -> io::Result<()> +async fn send(io: T, info: IdentifyInfo) -> Result<(), UpgradeError> where T: AsyncWrite + Unpin, { @@ -184,77 +188,99 @@ where protocols: info.protocols, }; - let mut bytes = Vec::with_capacity(message.encoded_len()); - message - .encode(&mut bytes) - .expect("Vec provides capacity as needed"); + let mut framed_io = FramedWrite::new( + io, + prost_codec::Codec::::new(MAX_MESSAGE_SIZE_BYTES), + ); - upgrade::write_length_prefixed(&mut io, bytes).await?; - io.close().await?; + framed_io.send(message).await?; + framed_io.close().await?; Ok(()) } -async fn recv(mut socket: T) -> io::Result +async fn recv(mut socket: T) -> Result where T: AsyncRead + AsyncWrite + Unpin, { socket.close().await?; - let msg = upgrade::read_length_prefixed(&mut socket, 4096) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) - .await?; - - let info = match parse_proto_msg(msg) { - Ok(v) => v, - Err(err) => { - debug!("Invalid message: {:?}", err); - return Err(err); - } - }; + let info = FramedRead::new( + socket, + prost_codec::Codec::::new(MAX_MESSAGE_SIZE_BYTES), + ) + .next() + .await + .ok_or(UpgradeError::StreamClosed)?? + .try_into()?; trace!("Received: {:?}", info); Ok(info) } -/// Turns a protobuf message into an `IdentifyInfo`. -fn parse_proto_msg(msg: impl AsRef<[u8]>) -> Result { - match structs_proto::Identify::decode(msg.as_ref()) { - Ok(msg) => { - fn parse_multiaddr(bytes: Vec) -> Result { - Multiaddr::try_from(bytes) - .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err)) - } +impl TryFrom for IdentifyInfo { + type Error = UpgradeError; - let listen_addrs = { - let mut addrs = Vec::new(); - for addr in msg.listen_addrs.into_iter() { - addrs.push(parse_multiaddr(addr)?); - } - addrs - }; - - let public_key = PublicKey::from_protobuf_encoding(&msg.public_key.unwrap_or_default()) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - - let observed_addr = parse_multiaddr(msg.observed_addr.unwrap_or_default())?; - let info = IdentifyInfo { - public_key, - protocol_version: msg.protocol_version.unwrap_or_default(), - agent_version: msg.agent_version.unwrap_or_default(), - listen_addrs, - protocols: msg.protocols, - observed_addr, - }; - - Ok(info) + fn try_from(msg: structs_proto::Identify) -> Result { + fn parse_multiaddr(bytes: Vec) -> Result { + Multiaddr::try_from(bytes) } - Err(err) => Err(io::Error::new(io::ErrorKind::InvalidData, err)), + let listen_addrs = { + let mut addrs = Vec::new(); + for addr in msg.listen_addrs.into_iter() { + addrs.push(parse_multiaddr(addr)?); + } + addrs + }; + + let public_key = PublicKey::from_protobuf_encoding(&msg.public_key.unwrap_or_default())?; + + let observed_addr = parse_multiaddr(msg.observed_addr.unwrap_or_default())?; + let info = IdentifyInfo { + public_key, + protocol_version: msg.protocol_version.unwrap_or_default(), + agent_version: msg.agent_version.unwrap_or_default(), + listen_addrs, + protocols: msg.protocols, + observed_addr, + }; + + Ok(info) } } +#[derive(Debug, Error)] +pub enum UpgradeError { + #[error("Failed to encode or decode")] + Codec( + #[from] + #[source] + prost_codec::Error, + ), + #[error("I/O interaction failed")] + Io( + #[from] + #[source] + io::Error, + ), + #[error("Stream closed")] + StreamClosed, + #[error("Failed decoding multiaddr")] + Multiaddr( + #[from] + #[source] + multiaddr::Error, + ), + #[error("Failed decoding public key")] + PublicKey( + #[from] + #[source] + identity::error::DecodingError, + ), +} + #[cfg(test)] mod tests { use super::*;