diff --git a/libp2p-identify/Cargo.toml b/libp2p-identify/Cargo.toml index 200e2604..fdb97399 100644 --- a/libp2p-identify/Cargo.toml +++ b/libp2p-identify/Cargo.toml @@ -11,6 +11,7 @@ libp2p-swarm = { path = "../libp2p-swarm" } multiaddr = "0.2.0" protobuf = "1.4.2" tokio-io = "0.1.0" +varint = { path = "../varint-rs" } [dev-dependencies] libp2p-tcp-transport = { path = "../libp2p-tcp-transport" } diff --git a/libp2p-identify/src/lib.rs b/libp2p-identify/src/lib.rs index 37c832aa..6fb3d219 100644 --- a/libp2p-identify/src/lib.rs +++ b/libp2p-identify/src/lib.rs @@ -31,6 +31,7 @@ extern crate libp2p_peerstore; extern crate libp2p_swarm; extern crate protobuf; extern crate tokio_io; +extern crate varint; use bytes::{Bytes, BytesMut}; use futures::{Future, Stream, Sink}; @@ -42,7 +43,7 @@ use protobuf::repeated::RepeatedField; use std::io::{Error as IoError, ErrorKind as IoErrorKind}; use std::iter; use tokio_io::{AsyncRead, AsyncWrite}; -use tokio_io::codec::length_delimited; +use varint::VarintCodec; mod structs_proto; @@ -94,8 +95,7 @@ impl ConnectionUpgrade for IdentifyProtocol } fn upgrade(self, socket: C, _: (), ty: Endpoint, remote_addr: &Multiaddr) -> Self::Future { - // TODO: use jack's varint library instead - let socket = length_delimited::Builder::new().length_field_length(1).new_framed(socket); + let socket = socket.framed(VarintCodec::default()); match ty { Endpoint::Dialer => { diff --git a/varint-rs/src/lib.rs b/varint-rs/src/lib.rs index 7c660bf6..5529c41c 100644 --- a/varint-rs/src/lib.rs +++ b/varint-rs/src/lib.rs @@ -31,14 +31,16 @@ extern crate futures; #[macro_use] extern crate error_chain; -use bytes::BytesMut; +use bytes::{BufMut, BytesMut, IntoBuf}; use futures::{Poll, Async}; use num_bigint::BigUint; use num_traits::ToPrimitive; use tokio_io::{AsyncRead, AsyncWrite}; -use tokio_io::codec::Decoder; +use tokio_io::codec::{Encoder, Decoder}; use std::io; use std::io::prelude::*; +use std::marker::PhantomData; +use std::mem; mod errors { error_chain! { @@ -395,6 +397,78 @@ impl Decoder for VarintDecoder { } } +#[derive(Debug)] +pub struct VarintCodec { + inner: VarintCodecInner, + marker: PhantomData, +} + +impl Default for VarintCodec { + #[inline] + fn default() -> VarintCodec { + VarintCodec { + inner: VarintCodecInner::WaitingForLen(VarintDecoder::default()), + marker: PhantomData, + } + } +} + +#[derive(Debug)] +enum VarintCodecInner { + WaitingForLen(VarintDecoder), + WaitingForData(usize), + Poisonned, +} + +impl Decoder for VarintCodec { + type Item = BytesMut; + type Error = io::Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + loop { + match mem::replace(&mut self.inner, VarintCodecInner::Poisonned) { + VarintCodecInner::WaitingForData(len) => { + if src.len() >= len { + self.inner = VarintCodecInner::WaitingForLen(VarintDecoder::default()); + return Ok(Some(src.split_to(len))); + } else { + self.inner = VarintCodecInner::WaitingForData(len); + return Ok(None); + } + }, + VarintCodecInner::WaitingForLen(mut decoder) => { + match decoder.decode(src)? { + None => { + self.inner = VarintCodecInner::WaitingForLen(decoder); + return Ok(None); + }, + Some(len) => { + self.inner = VarintCodecInner::WaitingForData(len); + }, + } + }, + VarintCodecInner::Poisonned => { + panic!("varint codec was poisoned") + }, + } + } + } +} + +impl Encoder for VarintCodec + where D: IntoBuf + AsRef<[u8]>, +{ + type Item = D; + type Error = io::Error; + + fn encode(&mut self, item: D, dst: &mut BytesMut) -> Result<(), io::Error> { + let encoded_len = encode(item.as_ref().len()); // TODO: can be optimized by not allocating? + dst.put(encoded_len); + dst.put(item); + Ok(()) + } +} + /// Syncronously decode a number from a `Read` pub fn decode(mut input: R) -> errors::Result { let mut decoder = DecoderState::default();