diff --git a/protocols/secio/Cargo.toml b/protocols/secio/Cargo.toml index 4023a0b9..0dd7fdf9 100644 --- a/protocols/secio/Cargo.toml +++ b/protocols/secio/Cargo.toml @@ -12,20 +12,19 @@ categories = ["network-programming", "asynchronous"] [dependencies] aes-ctr = "0.3" aesni = { version = "0.6", features = ["nocheck"], optional = true } -bytes = "0.4.12" ctr = "0.3" futures = "0.3.1" -futures_codec = "0.3.1" hmac = "0.7.0" lazy_static = "1.2.0" libp2p-core = { version = "0.13.0", path = "../../core" } log = "0.4.6" protobuf = "2.8" -rand = "0.6.5" +quicksink = { git = "https://github.com/paritytech/quicksink.git" } +rand = "0.7" rw-stream-sink = { version = "0.1.1", path = "../../misc/rw-stream-sink" } sha2 = "0.8.0" +static_assertions = "1" twofish = "0.2.0" -unsigned-varint = { version = "0.2.3", features = ["futures-codec"] } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] ring = { version = "0.16.9", features = ["alloc"], default-features = false } diff --git a/protocols/secio/src/codec/decode.rs b/protocols/secio/src/codec/decode.rs index 7a80bec0..14edb8ef 100644 --- a/protocols/secio/src/codec/decode.rs +++ b/protocols/secio/src/codec/decode.rs @@ -59,7 +59,7 @@ impl DecoderMiddleware { impl Stream for DecoderMiddleware where - S: TryStream + Unpin, + S: TryStream> + Unpin, S::Error: Into, { type Item = Result, SecioError>; @@ -87,10 +87,9 @@ where } } - let mut data_buf = frame.to_vec(); + let mut data_buf = frame; data_buf.truncate(content_length); - self.cipher_state - .decrypt(&mut data_buf); + self.cipher_state.decrypt(&mut data_buf); if !self.nonce.is_empty() { let n = min(data_buf.len(), self.nonce.len()); diff --git a/protocols/secio/src/codec/len_prefix.rs b/protocols/secio/src/codec/len_prefix.rs new file mode 100644 index 00000000..376d15c2 --- /dev/null +++ b/protocols/secio/src/codec/len_prefix.rs @@ -0,0 +1,124 @@ +// Copyright 2019 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +use futures::{prelude::*, stream::BoxStream}; +use quicksink::Action; +use std::{fmt, io, pin::Pin, task::{Context, Poll}}; + +/// `Stream` & `Sink` that reads and writes a length prefix in front of the actual data. +pub struct LenPrefixCodec { + stream: BoxStream<'static, io::Result>>, + sink: Pin, Error = io::Error> + Send>>, + _mark: std::marker::PhantomData +} + +impl fmt::Debug for LenPrefixCodec { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("LenPrefixCodec") + } +} + +static_assertions::const_assert! { + std::mem::size_of::() <= std::mem::size_of::() +} + +impl LenPrefixCodec +where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static +{ + pub fn new(socket: T, max_len: usize) -> Self { + let (r, w) = socket.split(); + + let stream = futures::stream::unfold(r, move |mut r| async move { + let mut len = [0; 4]; + if let Err(e) = r.read_exact(&mut len).await { + if e.kind() == io::ErrorKind::UnexpectedEof { + return None + } + return Some((Err(e), r)) + } + let n = u32::from_be_bytes(len) as usize; + if n > max_len { + let msg = format!("data length {} exceeds allowed maximum {}", n, max_len); + return Some((Err(io::Error::new(io::ErrorKind::PermissionDenied, msg)), r)) + } + let mut v = vec![0; n]; + if let Err(e) = r.read_exact(&mut v).await { + return Some((Err(e), r)) + } + Some((Ok(v), r)) + }); + + let sink = quicksink::make_sink(w, move |mut w, action: Action>| async move { + match action { + Action::Send(data) => { + if data.len() > max_len { + log::error!("data length {} exceeds allowed maximum {}", data.len(), max_len) + } + w.write_all(&(data.len() as u32).to_be_bytes()).await?; + w.write_all(&data).await? + } + Action::Flush => w.flush().await?, + Action::Close => w.close().await? + } + Ok(w) + }); + + LenPrefixCodec { + stream: stream.boxed(), + sink: Box::pin(sink), + _mark: std::marker::PhantomData + } + } +} + +impl Stream for LenPrefixCodec +where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static +{ + type Item = io::Result>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.stream.poll_next_unpin(cx) + } +} + +impl Sink> for LenPrefixCodec +where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static +{ + type Error = io::Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Pin::new(&mut self.sink).poll_ready(cx) + } + + fn start_send(mut self: Pin<&mut Self>, item: Vec) -> Result<(), Self::Error> { + Pin::new(&mut self.sink).start_send(item) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Pin::new(&mut self.sink).poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Pin::new(&mut self.sink).poll_close(cx) + } +} diff --git a/protocols/secio/src/codec/mod.rs b/protocols/secio/src/codec/mod.rs index e02bd00b..5e8ec83a 100644 --- a/protocols/secio/src/codec/mod.rs +++ b/protocols/secio/src/codec/mod.rs @@ -21,21 +21,22 @@ //! Individual messages encoding and decoding. Use this after the algorithms have been //! successfully negotiated. -use self::decode::DecoderMiddleware; -use self::encode::EncoderMiddleware; - -use crate::algo_support::Digest; -use futures::prelude::*; -use aes_ctr::stream_cipher; -use hmac::{self, Mac}; -use sha2::{Sha256, Sha512}; -use unsigned_varint::codec::UviBytes; - mod decode; mod encode; +mod len_prefix; + +use aes_ctr::stream_cipher; +use crate::algo_support::Digest; +use decode::DecoderMiddleware; +use encode::EncoderMiddleware; +use futures::prelude::*; +use hmac::{self, Mac}; +use sha2::{Sha256, Sha512}; + +pub use len_prefix::LenPrefixCodec; /// Type returned by `full_codec`. -pub type FullCodec = DecoderMiddleware>>>>; +pub type FullCodec = DecoderMiddleware>>; pub type StreamCipher = Box; @@ -108,7 +109,7 @@ impl Hmac { /// The conversion between the stream/sink items and the socket is done with the given cipher and /// hash algorithm (which are generally decided during the handshake). pub fn full_codec( - socket: futures_codec::Framed>>, + socket: LenPrefixCodec, cipher_encoding: StreamCipher, encoding_hmac: Hmac, cipher_decoder: StreamCipher, @@ -116,30 +117,27 @@ pub fn full_codec( remote_nonce: Vec ) -> FullCodec where - S: AsyncRead + AsyncWrite + Unpin, + S: AsyncRead + AsyncWrite + Unpin + Send + 'static { let encoder = EncoderMiddleware::new(socket, cipher_encoding, encoding_hmac); DecoderMiddleware::new(encoder, cipher_decoder, decoding_hmac, remote_nonce) } + #[cfg(test)] mod tests { - use super::{full_codec, DecoderMiddleware, EncoderMiddleware, Hmac}; + use super::{full_codec, DecoderMiddleware, EncoderMiddleware, Hmac, LenPrefixCodec}; use crate::algo_support::Digest; use crate::stream_cipher::{ctr, Cipher}; use crate::error::SecioError; use async_std::net::{TcpListener, TcpStream}; - use bytes::BytesMut; use futures::{prelude::*, channel::mpsc, channel::oneshot}; - use futures_codec::Framed; - use unsigned_varint::codec::UviBytes; const NULL_IV : [u8; 16] = [0; 16]; #[test] fn raw_encode_then_decode() { let (data_tx, data_rx) = mpsc::channel::>(256); - let data_rx = data_rx.map(BytesMut::from); let cipher_key: [u8; 32] = rand::random(); let hmac_key: [u8; 32] = rand::random(); @@ -184,7 +182,7 @@ mod tests { let (connec, _) = listener.accept().await.unwrap(); let codec = full_codec( - Framed::new(connec, UviBytes::default()), + LenPrefixCodec::new(connec, 1024), ctr(cipher, &cipher_key[..key_size], &NULL_IV[..]), Hmac::from_key(Digest::Sha256, &hmac_key), ctr(cipher, &cipher_key[..key_size], &NULL_IV[..]), @@ -200,7 +198,7 @@ mod tests { let listener_addr = l_a_rx.await.unwrap(); let stream = TcpStream::connect(&listener_addr).await.unwrap(); let mut codec = full_codec( - Framed::new(stream, UviBytes::default()), + LenPrefixCodec::new(stream, 1024), ctr(cipher, &cipher_key_clone[..key_size], &NULL_IV[..]), Hmac::from_key(Digest::Sha256, &hmac_key_clone), ctr(cipher, &cipher_key_clone[..key_size], &NULL_IV[..]), diff --git a/protocols/secio/src/handshake.rs b/protocols/secio/src/handshake.rs index b90ea93a..26dff527 100644 --- a/protocols/secio/src/handshake.rs +++ b/protocols/secio/src/handshake.rs @@ -18,22 +18,23 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +use crate::SecioConfig; use crate::algo_support; -use crate::codec::{full_codec, FullCodec, Hmac}; -use crate::stream_cipher::ctr; +use crate::codec::{full_codec, FullCodec, Hmac, LenPrefixCodec}; use crate::error::SecioError; use crate::exchange; +use crate::stream_cipher::ctr; +use crate::structs_proto::{Exchange, Propose}; use futures::prelude::*; use libp2p_core::PublicKey; use log::{debug, trace}; -use protobuf::parse_from_bytes as protobuf_parse_from_bytes; use protobuf::Message as ProtobufMessage; +use protobuf::parse_from_bytes as protobuf_parse_from_bytes; use rand::{self, RngCore}; use sha2::{Digest as ShaDigestTrait, Sha256}; use std::cmp::{self, Ordering}; use std::io::{Error as IoError, ErrorKind as IoErrorKind}; -use crate::structs_proto::{Exchange, Propose}; -use crate::SecioConfig; + /// Performs a handshake on the given socket. /// @@ -44,16 +45,12 @@ use crate::SecioConfig; /// On success, returns an object that implements the `Sink` and `Stream` trait whose items are /// buffers of data, plus the public key of the remote, plus the ephemeral public key used during /// negotiation. -pub async fn handshake<'a, S: 'a>(socket: S, config: SecioConfig) +pub async fn handshake(socket: S, config: SecioConfig) -> Result<(FullCodec, PublicKey, Vec), SecioError> where - S: AsyncRead + AsyncWrite + Send + Unpin, + S: AsyncRead + AsyncWrite + Send + Unpin + 'static { - // The handshake messages all start with a variable-length integer indicating the size. - let mut socket = futures_codec::Framed::new( - socket, - unsigned_varint::codec::UviBytes::>::default() - ); + let mut socket = LenPrefixCodec::new(socket, config.max_frame_len); let local_nonce = { let mut local_nonce = [0; 16]; diff --git a/protocols/secio/src/lib.rs b/protocols/secio/src/lib.rs index 205198d9..af55a279 100644 --- a/protocols/secio/src/lib.rs +++ b/protocols/secio/src/lib.rs @@ -85,7 +85,8 @@ pub struct SecioConfig { pub(crate) key: identity::Keypair, pub(crate) agreements_prop: Option, pub(crate) ciphers_prop: Option, - pub(crate) digests_prop: Option + pub(crate) digests_prop: Option, + pub(crate) max_frame_len: usize } impl SecioConfig { @@ -95,7 +96,8 @@ impl SecioConfig { key: kp, agreements_prop: None, ciphers_prop: None, - digests_prop: None + digests_prop: None, + max_frame_len: 8 * 1024 * 1024 } } @@ -126,6 +128,12 @@ impl SecioConfig { self } + /// Override the default max. frame length of 8MiB. + pub fn max_frame_len(mut self, n: usize) -> Self { + self.max_frame_len = n; + self + } + fn handshake(self, socket: T) -> impl Future), SecioError>> where T: AsyncRead + AsyncWrite + Unpin + Send + 'static @@ -148,7 +156,7 @@ impl SecioConfig { /// Output of the secio protocol. pub struct SecioOutput where - S: AsyncRead + AsyncWrite + Unpin + S: AsyncRead + AsyncWrite + Unpin + Send + 'static { /// The encrypted stream. pub stream: RwStreamSink, fn(SecioError) -> io::Error>>, @@ -193,7 +201,10 @@ where } } -impl AsyncRead for SecioOutput { +impl AsyncRead for SecioOutput +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static +{ fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { @@ -201,7 +212,10 @@ impl AsyncRead for SecioOutput { } } -impl AsyncWrite for SecioOutput { +impl AsyncWrite for SecioOutput +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static +{ fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { @@ -254,7 +268,7 @@ where impl Sink> for SecioMiddleware where - S: AsyncRead + AsyncWrite + Unpin, + S: AsyncRead + AsyncWrite + Unpin + Send + 'static { type Error = io::Error; @@ -277,7 +291,7 @@ where impl Stream for SecioMiddleware where - S: AsyncRead + AsyncWrite + Unpin, + S: AsyncRead + AsyncWrite + Unpin + Send + 'static { type Item = Result, SecioError>;