misc/prost-codec: Introduce codec for varint prefixed Protobuf messages (#2630)

Extracts the Protobuf en-/decoding pattern into its separate crate
and applies it to `libp2p-identify`.
This commit is contained in:
Max Inden 2022-05-05 18:28:47 +02:00 committed by GitHub
parent 3cfbf89a3a
commit bbd2f8f009
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 234 additions and 192 deletions

View File

@ -126,6 +126,7 @@ members = [
"misc/metrics",
"misc/multistream-select",
"misc/keygen",
"misc/prost-codec",
"muxers/mplex",
"muxers/yamux",
"protocols/dcutr",

View File

@ -0,0 +1,21 @@
[package]
name = "prost-codec"
edition = "2021"
rust-version = "1.56.1"
description = "Asynchronous de-/encoding of Protobuf structs using asynchronous-codec, unsigned-varint and prost."
version = "0.1.0"
authors = ["Max Inden <mail@max-inden.de>"]
license = "MIT"
repository = "https://github.com/libp2p/rust-libp2p"
keywords = ["networking"]
categories = ["asynchronous"]
[dependencies]
asynchronous-codec = { version = "0.6" }
bytes = { version = "1" }
prost = "0.10"
thiserror = "1.0"
unsigned-varint = { version = "0.7", features = ["asynchronous_codec"] }
[dev-dependencies]
prost-build = "0.10"

View File

@ -0,0 +1,81 @@
use asynchronous_codec::{Decoder, Encoder};
use bytes::BytesMut;
use prost::Message;
use std::io::Cursor;
use std::marker::PhantomData;
use thiserror::Error;
use unsigned_varint::codec::UviBytes;
/// [`Codec`] implements [`Encoder`] and [`Decoder`], uses [`unsigned_varint`]
/// to prefix messages with their length and uses [`prost`] and a provided
/// `struct` implementing [`Message`] to do the encoding.
pub struct Codec<In, Out = In> {
uvi: UviBytes,
phantom: PhantomData<(In, Out)>,
}
impl<In, Out> Codec<In, Out> {
/// Create new [`Codec`].
///
/// Parameter `max_message_len_bytes` determines the maximum length of the
/// Protobuf message. The limit does not include the bytes needed for the
/// [`unsigned_varint`].
pub fn new(max_message_len_bytes: usize) -> Self {
let mut uvi = UviBytes::default();
uvi.set_max_len(max_message_len_bytes);
Self {
uvi,
phantom: PhantomData::default(),
}
}
}
impl<In: Message, Out> Encoder for Codec<In, Out> {
type Item = In;
type Error = Error;
fn encode(
&mut self,
item: Self::Item,
dst: &mut asynchronous_codec::BytesMut,
) -> Result<(), Self::Error> {
let mut encoded_msg = BytesMut::new();
item.encode(&mut encoded_msg)
.expect("BytesMut to have sufficient capacity.");
self.uvi
.encode(encoded_msg.freeze(), dst)
.map_err(|e| e.into())
}
}
impl<In, Out: Message + Default> Decoder for Codec<In, Out> {
type Item = Out;
type Error = Error;
fn decode(
&mut self,
src: &mut asynchronous_codec::BytesMut,
) -> Result<Option<Self::Item>, Self::Error> {
Ok(self
.uvi
.decode(src)?
.map(|msg| Message::decode(Cursor::new(msg)))
.transpose()?)
}
}
#[derive(Debug, Error)]
pub enum Error {
#[error("Failed to decode response: {0}.")]
Decode(
#[from]
#[source]
prost::DecodeError,
),
#[error("Io error {0}")]
Io(
#[from]
#[source]
std::io::Error,
),
}

View File

@ -20,9 +20,9 @@ instant = "0.1.11"
libp2p-core = { version = "0.33.0", path = "../../core" }
libp2p-swarm = { version = "0.36.0", path = "../../swarm" }
log = "0.4"
prost-codec = { version = "0.1", path = "../../misc/prost-codec" }
prost = "0.10"
thiserror = "1.0"
unsigned-varint = { version = "0.7", features = ["asynchronous_codec"] }
void = "1"
[build-dependencies]

View File

@ -18,8 +18,9 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
mod codec;
pub mod inbound;
pub mod outbound;
const PROTOCOL_NAME: &[u8; 13] = b"/libp2p/dcutr";
const MAX_MESSAGE_SIZE_BYTES: usize = 4096;

View File

@ -1,88 +0,0 @@
// Copyright 2022 Protocol Labs.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use crate::message_proto;
use bytes::BytesMut;
use prost::Message;
use std::io::Cursor;
use thiserror::Error;
use unsigned_varint::codec::UviBytes;
const MAX_MESSAGE_SIZE_BYTES: usize = 4096;
pub struct Codec(UviBytes);
impl Codec {
pub fn new() -> Self {
let mut codec = UviBytes::default();
codec.set_max_len(MAX_MESSAGE_SIZE_BYTES);
Self(codec)
}
}
impl asynchronous_codec::Encoder for Codec {
type Item = message_proto::HolePunch;
type Error = Error;
fn encode(
&mut self,
item: Self::Item,
dst: &mut asynchronous_codec::BytesMut,
) -> Result<(), Self::Error> {
let mut encoded_msg = BytesMut::new();
item.encode(&mut encoded_msg)
.expect("BytesMut to have sufficient capacity.");
self.0
.encode(encoded_msg.freeze(), dst)
.map_err(|e| e.into())
}
}
impl asynchronous_codec::Decoder for Codec {
type Item = message_proto::HolePunch;
type Error = Error;
fn decode(
&mut self,
src: &mut asynchronous_codec::BytesMut,
) -> Result<Option<Self::Item>, Self::Error> {
Ok(self
.0
.decode(src)?
.map(|msg| message_proto::HolePunch::decode(Cursor::new(msg)))
.transpose()?)
}
}
#[derive(Debug, Error)]
pub enum Error {
#[error("Failed to decode response: {0}.")]
Decode(
#[from]
#[source]
prost::DecodeError,
),
#[error("Io error {0}")]
Io(
#[from]
#[source]
std::io::Error,
),
}

View File

@ -44,17 +44,14 @@ impl upgrade::InboundUpgrade<NegotiatedSubstream> for Upgrade {
type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
fn upgrade_inbound(self, substream: NegotiatedSubstream, _: Self::Info) -> Self::Future {
let mut substream = Framed::new(substream, super::codec::Codec::new());
let mut substream = Framed::new(
substream,
prost_codec::Codec::new(super::MAX_MESSAGE_SIZE_BYTES),
);
async move {
let HolePunch { r#type, obs_addrs } =
substream
.next()
.await
.ok_or(super::codec::Error::Io(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"",
)))??;
substream.next().await.ok_or(UpgradeError::StreamClosed)??;
let obs_addrs = if obs_addrs.is_empty() {
return Err(UpgradeError::NoAddresses);
@ -88,7 +85,7 @@ impl upgrade::InboundUpgrade<NegotiatedSubstream> for Upgrade {
}
pub struct PendingConnect {
substream: Framed<NegotiatedSubstream, super::codec::Codec>,
substream: Framed<NegotiatedSubstream, prost_codec::Codec<HolePunch>>,
remote_obs_addrs: Vec<Multiaddr>,
}
@ -103,14 +100,11 @@ impl PendingConnect {
};
self.substream.send(msg).await?;
let HolePunch { r#type, .. } =
self.substream
.next()
.await
.ok_or(super::codec::Error::Io(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"",
)))??;
let HolePunch { r#type, .. } = self
.substream
.next()
.await
.ok_or(UpgradeError::StreamClosed)??;
let r#type = hole_punch::Type::from_i32(r#type).ok_or(UpgradeError::ParseTypeField)?;
match r#type {
@ -124,12 +118,14 @@ impl PendingConnect {
#[derive(Debug, Error)]
pub enum UpgradeError {
#[error("Failed to encode or decode: {0}")]
#[error("Failed to encode or decode")]
Codec(
#[from]
#[source]
super::codec::Error,
prost_codec::Error,
),
#[error("Stream closed")]
StreamClosed,
#[error("Expected at least one address in reservation.")]
NoAddresses,
#[error("Invalid addresses.")]

View File

@ -54,7 +54,10 @@ impl upgrade::OutboundUpgrade<NegotiatedSubstream> for Upgrade {
type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
fn upgrade_outbound(self, substream: NegotiatedSubstream, _: Self::Info) -> Self::Future {
let mut substream = Framed::new(substream, super::codec::Codec::new());
let mut substream = Framed::new(
substream,
prost_codec::Codec::new(super::MAX_MESSAGE_SIZE_BYTES),
);
let msg = HolePunch {
r#type: hole_punch::Type::Connect.into(),
@ -67,13 +70,7 @@ impl upgrade::OutboundUpgrade<NegotiatedSubstream> for Upgrade {
let sent_time = Instant::now();
let HolePunch { r#type, obs_addrs } =
substream
.next()
.await
.ok_or(super::codec::Error::Io(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"",
)))??;
substream.next().await.ok_or(UpgradeError::StreamClosed)??;
let rtt = sent_time.elapsed();
@ -123,8 +120,10 @@ pub enum UpgradeError {
Codec(
#[from]
#[source]
super::codec::Error,
prost_codec::Error,
),
#[error("Stream closed")]
StreamClosed,
#[error("Expected 'status' field to be set.")]
MissingStatusField,
#[error("Expected 'reservation' field to be set.")]

View File

@ -4,6 +4,9 @@
- Update to `libp2p-swarm` `v0.36.0`.
- Expose explicits errors via `UpgradeError` instead of generic `io::Error`. See [PR 2630].
[PR 2630]: https://github.com/libp2p/rust-libp2p/pull/2630
# 0.35.0
- Update to `libp2p-swarm` `v0.35.0`.

View File

@ -11,14 +11,17 @@ keywords = ["peer-to-peer", "libp2p", "networking"]
categories = ["network-programming", "asynchronous"]
[dependencies]
asynchronous-codec = "0.6"
futures = "0.3.1"
futures-timer = "3.0.2"
libp2p-core = { version = "0.33.0", path = "../../core", default-features = false }
libp2p-swarm = { version = "0.36.0", path = "../../swarm" }
log = "0.4.1"
lru = "0.7.2"
prost-codec = { version = "0.1", path = "../../misc/prost-codec" }
prost = "0.10"
smallvec = "1.6.1"
thiserror = "1.0"
[dev-dependencies]
async-std = "1.6.2"

View File

@ -19,14 +19,13 @@
// DEALINGS IN THE SOFTWARE.
use crate::protocol::{
IdentifyInfo, IdentifyProtocol, IdentifyPushProtocol, InboundPush, OutboundPush, ReplySubstream,
IdentifyInfo, IdentifyProtocol, IdentifyPushProtocol, InboundPush, OutboundPush,
ReplySubstream, UpgradeError,
};
use futures::prelude::*;
use futures_timer::Delay;
use libp2p_core::either::{EitherError, EitherOutput};
use libp2p_core::upgrade::{
EitherUpgrade, InboundUpgrade, OutboundUpgrade, SelectUpgrade, UpgradeError,
};
use libp2p_core::upgrade::{EitherUpgrade, InboundUpgrade, OutboundUpgrade, SelectUpgrade};
use libp2p_swarm::{
ConnectionHandler, ConnectionHandlerEvent, ConnectionHandlerUpgrErr, KeepAlive,
NegotiatedSubstream, SubstreamProtocol,
@ -70,7 +69,7 @@ pub enum IdentifyHandlerEvent {
/// We received a request for identification.
Identify(ReplySubstream<NegotiatedSubstream>),
/// Failed to identify the remote.
IdentificationError(ConnectionHandlerUpgrErr<io::Error>),
IdentificationError(ConnectionHandlerUpgrErr<UpgradeError>),
}
/// Identifying information of the local node that is pushed to a remote.
@ -155,6 +154,8 @@ impl ConnectionHandler for IdentifyHandler {
<Self::OutboundProtocol as OutboundUpgrade<NegotiatedSubstream>>::Error,
>,
) {
use libp2p_core::upgrade::UpgradeError;
let err = err.map_upgrade_err(|e| match e {
UpgradeError::Select(e) => UpgradeError::Select(e),
UpgradeError::Apply(EitherError::A(ioe)) => UpgradeError::Apply(ioe),

View File

@ -19,12 +19,11 @@
// DEALINGS IN THE SOFTWARE.
use crate::handler::{IdentifyHandler, IdentifyHandlerEvent, IdentifyPush};
use crate::protocol::{IdentifyInfo, ReplySubstream};
use crate::protocol::{IdentifyInfo, ReplySubstream, UpgradeError};
use futures::prelude::*;
use libp2p_core::{
connection::{ConnectionId, ListenerId},
multiaddr::Protocol,
upgrade::UpgradeError,
ConnectedPoint, Multiaddr, PeerId, PublicKey,
};
use libp2p_swarm::{
@ -35,7 +34,6 @@ use libp2p_swarm::{
use lru::LruCache;
use std::{
collections::{HashMap, HashSet, VecDeque},
io,
iter::FromIterator,
pin::Pin,
task::Context,
@ -75,7 +73,7 @@ enum Reply {
/// The reply is being sent.
Sending {
peer: PeerId,
io: Pin<Box<dyn Future<Output = Result<(), io::Error>> + Send>>,
io: Pin<Box<dyn Future<Output = Result<(), UpgradeError>> + Send>>,
},
}
@ -429,9 +427,9 @@ impl NetworkBehaviour for Identify {
Poll::Ready(Err(err)) => {
let event = IdentifyEvent::Error {
peer_id: peer,
error: ConnectionHandlerUpgrErr::Upgrade(UpgradeError::Apply(
err,
)),
error: ConnectionHandlerUpgrErr::Upgrade(
libp2p_core::upgrade::UpgradeError::Apply(err),
),
};
return Poll::Ready(NetworkBehaviourAction::GenerateEvent(event));
}
@ -482,7 +480,7 @@ pub enum IdentifyEvent {
/// The peer with whom the error originated.
peer_id: PeerId,
/// The error that occurred.
error: ConnectionHandlerUpgrErr<io::Error>,
error: ConnectionHandlerUpgrErr<UpgradeError>,
},
}

View File

@ -45,7 +45,7 @@
//! [`IdentifyInfo`]: self::IdentifyInfo
pub use self::identify::{Identify, IdentifyConfig, IdentifyEvent};
pub use self::protocol::IdentifyInfo;
pub use self::protocol::{IdentifyInfo, UpgradeError};
mod handler;
mod identify;

View File

@ -19,15 +19,19 @@
// DEALINGS IN THE SOFTWARE.
use crate::structs_proto;
use asynchronous_codec::{FramedRead, FramedWrite};
use futures::prelude::*;
use libp2p_core::{
upgrade::{self, InboundUpgrade, OutboundUpgrade, UpgradeInfo},
identity, multiaddr,
upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo},
Multiaddr, PublicKey,
};
use log::{debug, trace};
use prost::Message;
use log::trace;
use std::convert::TryFrom;
use std::{fmt, io, iter, pin::Pin};
use thiserror::Error;
const MAX_MESSAGE_SIZE_BYTES: usize = 4096;
/// Substream upgrade protocol for `/ipfs/id/1.0.0`.
#[derive(Debug, Clone)]
@ -89,8 +93,8 @@ where
///
/// Consumes the substream, returning a future that resolves
/// when the reply has been sent on the underlying connection.
pub async fn send(self, info: IdentifyInfo) -> io::Result<()> {
send(self.inner, info).await
pub async fn send(self, info: IdentifyInfo) -> Result<(), UpgradeError> {
send(self.inner, info).await.map_err(Into::into)
}
}
@ -105,8 +109,8 @@ impl UpgradeInfo for IdentifyProtocol {
impl<C> InboundUpgrade<C> for IdentifyProtocol {
type Output = ReplySubstream<C>;
type Error = io::Error;
type Future = future::Ready<Result<Self::Output, io::Error>>;
type Error = UpgradeError;
type Future = future::Ready<Result<Self::Output, UpgradeError>>;
fn upgrade_inbound(self, socket: C, _: Self::Info) -> Self::Future {
future::ok(ReplySubstream { inner: socket })
@ -118,7 +122,7 @@ where
C: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Output = IdentifyInfo;
type Error = io::Error;
type Error = UpgradeError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
fn upgrade_outbound(self, socket: C, _: Self::Info) -> Self::Future {
@ -140,7 +144,7 @@ where
C: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Output = IdentifyInfo;
type Error = io::Error;
type Error = UpgradeError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
fn upgrade_inbound(self, socket: C, _: Self::Info) -> Self::Future {
@ -153,7 +157,7 @@ where
C: AsyncWrite + Unpin + Send + 'static,
{
type Output = ();
type Error = io::Error;
type Error = UpgradeError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
fn upgrade_outbound(self, socket: C, _: Self::Info) -> Self::Future {
@ -161,7 +165,7 @@ where
}
}
async fn send<T>(mut io: T, info: IdentifyInfo) -> io::Result<()>
async fn send<T>(io: T, info: IdentifyInfo) -> Result<(), UpgradeError>
where
T: AsyncWrite + Unpin,
{
@ -184,77 +188,99 @@ where
protocols: info.protocols,
};
let mut bytes = Vec::with_capacity(message.encoded_len());
message
.encode(&mut bytes)
.expect("Vec<u8> provides capacity as needed");
let mut framed_io = FramedWrite::new(
io,
prost_codec::Codec::<structs_proto::Identify>::new(MAX_MESSAGE_SIZE_BYTES),
);
upgrade::write_length_prefixed(&mut io, bytes).await?;
io.close().await?;
framed_io.send(message).await?;
framed_io.close().await?;
Ok(())
}
async fn recv<T>(mut socket: T) -> io::Result<IdentifyInfo>
async fn recv<T>(mut socket: T) -> Result<IdentifyInfo, UpgradeError>
where
T: AsyncRead + AsyncWrite + Unpin,
{
socket.close().await?;
let msg = upgrade::read_length_prefixed(&mut socket, 4096)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
.await?;
let info = match parse_proto_msg(msg) {
Ok(v) => v,
Err(err) => {
debug!("Invalid message: {:?}", err);
return Err(err);
}
};
let info = FramedRead::new(
socket,
prost_codec::Codec::<structs_proto::Identify>::new(MAX_MESSAGE_SIZE_BYTES),
)
.next()
.await
.ok_or(UpgradeError::StreamClosed)??
.try_into()?;
trace!("Received: {:?}", info);
Ok(info)
}
/// Turns a protobuf message into an `IdentifyInfo`.
fn parse_proto_msg(msg: impl AsRef<[u8]>) -> Result<IdentifyInfo, io::Error> {
match structs_proto::Identify::decode(msg.as_ref()) {
Ok(msg) => {
fn parse_multiaddr(bytes: Vec<u8>) -> Result<Multiaddr, io::Error> {
Multiaddr::try_from(bytes)
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))
}
impl TryFrom<structs_proto::Identify> for IdentifyInfo {
type Error = UpgradeError;
let listen_addrs = {
let mut addrs = Vec::new();
for addr in msg.listen_addrs.into_iter() {
addrs.push(parse_multiaddr(addr)?);
}
addrs
};
let public_key = PublicKey::from_protobuf_encoding(&msg.public_key.unwrap_or_default())
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
let observed_addr = parse_multiaddr(msg.observed_addr.unwrap_or_default())?;
let info = IdentifyInfo {
public_key,
protocol_version: msg.protocol_version.unwrap_or_default(),
agent_version: msg.agent_version.unwrap_or_default(),
listen_addrs,
protocols: msg.protocols,
observed_addr,
};
Ok(info)
fn try_from(msg: structs_proto::Identify) -> Result<Self, Self::Error> {
fn parse_multiaddr(bytes: Vec<u8>) -> Result<Multiaddr, multiaddr::Error> {
Multiaddr::try_from(bytes)
}
Err(err) => Err(io::Error::new(io::ErrorKind::InvalidData, err)),
let listen_addrs = {
let mut addrs = Vec::new();
for addr in msg.listen_addrs.into_iter() {
addrs.push(parse_multiaddr(addr)?);
}
addrs
};
let public_key = PublicKey::from_protobuf_encoding(&msg.public_key.unwrap_or_default())?;
let observed_addr = parse_multiaddr(msg.observed_addr.unwrap_or_default())?;
let info = IdentifyInfo {
public_key,
protocol_version: msg.protocol_version.unwrap_or_default(),
agent_version: msg.agent_version.unwrap_or_default(),
listen_addrs,
protocols: msg.protocols,
observed_addr,
};
Ok(info)
}
}
#[derive(Debug, Error)]
pub enum UpgradeError {
#[error("Failed to encode or decode")]
Codec(
#[from]
#[source]
prost_codec::Error,
),
#[error("I/O interaction failed")]
Io(
#[from]
#[source]
io::Error,
),
#[error("Stream closed")]
StreamClosed,
#[error("Failed decoding multiaddr")]
Multiaddr(
#[from]
#[source]
multiaddr::Error,
),
#[error("Failed decoding public key")]
PublicKey(
#[from]
#[source]
identity::error::DecodingError,
),
}
#[cfg(test)]
mod tests {
use super::*;