diff --git a/protocols/secio/src/codec/decode.rs b/protocols/secio/src/codec/decode.rs index 2480222a..c0bd8967 100644 --- a/protocols/secio/src/codec/decode.rs +++ b/protocols/secio/src/codec/decode.rs @@ -29,6 +29,7 @@ use futures::stream::Stream; use futures::Async; use futures::Poll; use futures::StartSend; +use std::cmp::min; /// Wraps around a `Stream`. The buffers produced by the underlying stream /// are decoded using the cipher and hmac. @@ -42,19 +43,21 @@ pub struct DecoderMiddleware { cipher_state: StreamCipher, hmac: Hmac, raw_stream: S, + nonce: Vec } impl DecoderMiddleware { + /// Create a new decoder for the given stream, using the provided cipher and HMAC. + /// + /// The `nonce` parameter denotes a sequence of bytes which are expected to be found at the + /// beginning of the stream and are checked for equality. #[inline] - pub fn new( - raw_stream: S, - cipher: StreamCipher, - hmac: Hmac, - ) -> DecoderMiddleware { + pub fn new(raw_stream: S, cipher: StreamCipher, hmac: Hmac, nonce: Vec) -> DecoderMiddleware { DecoderMiddleware { cipher_state: cipher, hmac, raw_stream, + nonce } } } @@ -97,6 +100,15 @@ where .try_apply_keystream(&mut data_buf) .map_err::(|e|e.into())?; + if !self.nonce.is_empty() { + let n = min(data_buf.len(), self.nonce.len()); + if &data_buf[.. n] != &self.nonce[.. n] { + return Err(SecioError::NonceVerificationFailed) + } + self.nonce.drain(.. n); + data_buf.drain(.. n); + } + Ok(Async::Ready(Some(data_buf))) } } diff --git a/protocols/secio/src/codec/mod.rs b/protocols/secio/src/codec/mod.rs index a15d3c86..cb693185 100644 --- a/protocols/secio/src/codec/mod.rs +++ b/protocols/secio/src/codec/mod.rs @@ -109,12 +109,13 @@ pub fn full_codec( encoding_hmac: Hmac, cipher_decoder: StreamCipher, decoding_hmac: Hmac, + remote_nonce: Vec ) -> FullCodec where S: AsyncRead + AsyncWrite, { let encoder = EncoderMiddleware::new(socket, cipher_encoding, encoding_hmac); - DecoderMiddleware::new(encoder, cipher_decoder, decoding_hmac) + DecoderMiddleware::new(encoder, cipher_decoder, decoding_hmac, remote_nonce) } #[cfg(test)] @@ -133,7 +134,7 @@ mod tests { use bytes::BytesMut; use error::SecioError; use futures::sync::mpsc::channel; - use futures::{Future, Sink, Stream}; + use futures::{Future, Sink, Stream, stream}; use rand; use std::io::Error as IoError; use tokio_io::codec::length_delimited::Framed; @@ -159,6 +160,7 @@ mod tests { data_rx, ctr(Cipher::Aes256, &cipher_key, &NULL_IV[..]), Hmac::from_key(Digest::Sha256, &hmac_key), + Vec::new() ); let data = b"hello world"; @@ -181,20 +183,23 @@ mod tests { let hmac_key_clone = hmac_key.clone(); let data = b"hello world"; let data_clone = data.clone(); + let nonce = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); let listener_addr = listener.local_addr().unwrap(); - let server = listener.incoming().into_future().map_err(|(e, _)| e).map( - move |(connec, _)| { - let connec = Framed::new(connec.unwrap()); - + let nonce2 = nonce.clone(); + let server = listener.incoming() + .into_future() + .map_err(|(e, _)| e) + .map(move |(connec, _)| { full_codec( - connec, + Framed::new(connec.unwrap()), ctr(cipher, &cipher_key[..key_size], &NULL_IV[..]), Hmac::from_key(Digest::Sha256, &hmac_key), ctr(cipher, &cipher_key[..key_size], &NULL_IV[..]), Hmac::from_key(Digest::Sha256, &hmac_key), + nonce2 ) }, ); @@ -202,14 +207,13 @@ mod tests { let client = TcpStream::connect(&listener_addr) .map_err(|e| e.into()) .map(move |stream| { - let stream = Framed::new(stream); - full_codec( - stream, + Framed::new(stream), ctr(cipher, &cipher_key_clone[..key_size], &NULL_IV[..]), Hmac::from_key(Digest::Sha256, &hmac_key_clone), ctr(cipher, &cipher_key_clone[..key_size], &NULL_IV[..]), Hmac::from_key(Digest::Sha256, &hmac_key_clone), + Vec::new() ) }); @@ -218,12 +222,11 @@ mod tests { .from_err::() .and_then(|(server, client)| { client - .send(BytesMut::from(&data_clone[..])) + .send_all(stream::iter_ok::<_, IoError>(vec![nonce.into(), data_clone[..].into()])) .map(move |_| server) .from_err() }) - .and_then(|server| server.into_future().map_err(|(e, _)| e.into())) - .map(|recved| recved.0.unwrap().to_vec()); + .and_then(|server| server.concat2().from_err()); let mut rt = Runtime::new().unwrap(); let received = rt.block_on(fin).unwrap(); diff --git a/protocols/secio/src/handshake.rs b/protocols/secio/src/handshake.rs index 43807ee9..18e9e4ee 100644 --- a/protocols/secio/src/handshake.rs +++ b/protocols/secio/src/handshake.rs @@ -307,10 +307,8 @@ impl HandshakeContext { /// On success, returns an object that implements the `Sink` and `Stream` trait whose items are /// buffers of data, plus the public key of the remote, plus the ephemeral public key used during /// negotiation. -pub fn handshake<'a, S: 'a>( - socket: S, - config: SecioConfig -) -> Box, PublicKey, Vec), Error = SecioError> + Send + 'a> +pub fn handshake<'a, S: 'a>(socket: S, config: SecioConfig) + -> impl Future, PublicKey, Vec), Error = SecioError> where S: AsyncRead + AsyncWrite + Send, { @@ -320,7 +318,7 @@ where .length_field_length(4) .new_framed(socket); - let future = future::ok::<_, SecioError>(HandshakeContext::new(config)) + future::ok::<_, SecioError>(HandshakeContext::new(config)) .and_then(|context| { // Generate our nonce. let context = context.with_local()?; @@ -570,7 +568,14 @@ where (cipher, hmac) }; - let codec = full_codec(socket, encoding_cipher, encoding_hmac, decoding_cipher, decoding_hmac); + let codec = full_codec( + socket, + encoding_cipher, + encoding_hmac, + decoding_cipher, + decoding_hmac, + context.state.remote.local.nonce.to_vec() + ); Ok((codec, context)) }) // We send back their nonce to check if the connection works. @@ -578,32 +583,9 @@ where let remote_nonce = context.state.remote.nonce.clone(); trace!("checking encryption by sending back remote's nonce"); codec.send(BytesMut::from(remote_nonce)) - .map(|s| (s, context)) + .map(|s| (s, context.state.remote.public_key, context.state.local_tmp_pub_key)) .from_err() }) - // Check that the received nonce is correct. - .and_then(|(codec, context)| { - codec.into_future() - .map_err(|(e, _)| e) - .and_then(move |(nonce, rest)| { - match nonce { - Some(ref n) if n == &context.state.remote.local.nonce => { - trace!("secio handshake success"); - Ok((rest, context.state.remote.public_key, context.state.local_tmp_pub_key)) - }, - None => { - debug!("unexpected eof during nonce check"); - Err(IoError::new(IoErrorKind::BrokenPipe, "unexpected eof").into()) - }, - _ => { - debug!("failed nonce verification with remote"); - Err(SecioError::NonceVerificationFailed) - } - } - }) - }); - - Box::new(future) } /// Custom algorithm translated from reference implementations. Needs to be the same algorithm @@ -647,15 +629,16 @@ where ::hmac::Hmac: Clone { mod tests { extern crate tokio; extern crate tokio_tcp; + use bytes::BytesMut; use self::tokio::runtime::current_thread::Runtime; use self::tokio_tcp::TcpListener; use self::tokio_tcp::TcpStream; + use crate::SecioError; use super::handshake; use super::stretch_key; use algo_support::Digest; use codec::Hmac; - use futures::Future; - use futures::Stream; + use futures::prelude::*; use {SecioConfig, SecioKeyPair}; #[test] @@ -707,11 +690,29 @@ mod tests { .incoming() .into_future() .map_err(|(e, _)| e.into()) - .and_then(move |(connec, _)| handshake(connec.unwrap(), key1)); + .and_then(move |(connec, _)| handshake(connec.unwrap(), key1)) + .and_then(|(connec, _, _)| { + let (sink, stream) = connec.split(); + stream + .filter(|v| !v.is_empty()) + .forward(sink.with(|v| Ok::<_, SecioError>(BytesMut::from(v)))) + }); let client = TcpStream::connect(&listener_addr) .map_err(|e| e.into()) - .and_then(move |stream| handshake(stream, key2)); + .and_then(move |stream| handshake(stream, key2)) + .and_then(|(connec, _, _)| { + connec.send("hello".into()) + .from_err() + .and_then(|connec| { + connec.filter(|v| !v.is_empty()) + .into_future() + .map(|(v, _)| v) + .map_err(|(e, _)| e) + }) + .map(|v| assert_eq!(b"hello", &v.unwrap()[..])) + }); + let mut rt = Runtime::new().unwrap(); let _ = rt.block_on(server.join(client)).unwrap(); } diff --git a/protocols/secio/src/lib.rs b/protocols/secio/src/lib.rs index 9d768c13..9951ce15 100644 --- a/protocols/secio/src/lib.rs +++ b/protocols/secio/src/lib.rs @@ -427,19 +427,13 @@ where /// /// On success, produces a `SecioMiddleware` that can then be used to encode/decode /// communications, plus the public key of the remote, plus the ephemeral public key. - pub fn handshake<'a>( - socket: S, - config: SecioConfig, - ) -> Box, PublicKey, Vec), Error = SecioError> + Send + 'a> - where - S: 'a, + pub fn handshake(socket: S, config: SecioConfig) + -> impl Future, PublicKey, Vec), Error = SecioError> { - let fut = handshake::handshake(socket, config).map(|(inner, pubkey, ephemeral)| { + handshake::handshake(socket, config).map(|(inner, pubkey, ephemeral)| { let inner = SecioMiddleware { inner }; (inner, pubkey, ephemeral) - }); - - Box::new(fut) + }) } }