protocols/noise: Update to futures-preview (#1248)

* protocols/noise: Fix obvious future errors

* protocol/noise: Make Handshake methods independent functions

* protocols/noise: Abstract T and C for handshake

* protocols/noise: Replace FutureResult with Result

* protocols/noise: Introduce recv_identity stub

* protocols/noise: Implement recv_identity stub

* protocols/noise: Change NoiseConfig::Future from Handshake to Result

* protocols/noise: Adjust to new Poll syntax

* protocols/noise: Return early on state creation failure

* protocols/noise: Add bounds Async{Write,Read} to initiator / respoder

* protocols/noise: Add Protocol trait bound for C in rt functions

* protocols/noise: Do io operations on state.io instead of state

* protocols/noise: Have upgrade_xxx return a pinned future

* protocols/noise: Have NoiseOutput::poll_read self be mutable

* protocols/noise: Make recv_identity buffers mutable

* protocols/noise: Fix warnings

* protocols/noise: Replace NoiseOutput io::Read impl with AsyncRead

* protocols/noise: Replace NoiseOutput io::Write impl with AsyncWrite

* protocols/noise: Adjust tests to new futures

* protocols/noise: Don't use {AsyncRead,AsyncWrite,TryStream}*Ext* bound

* protocols/noise: Don't use async_closure feature

* protocols/noise: use futures::ready! macro

* protocols/noise: Make NoiseOutput AsyncRead return unsafe NopInitializer

The previous implementation of AsyncRead for NoiseOutput would operate
on uninitialized buffers, given that it properly returned the number of
bytest that were written to the buffer. With this patch the current
implementation operates on uninitialized buffers as well by returning an
Initializer::nop() in AsyncRead::initializer.

* protocols/noise: Remove resolved TODO questions

* protocols/noise: Remove 'this = self' comment

Given that `let mut this = &mut *self` is not specific to a pinned self,
but follows the dereference coercion [1] happening at compile time when
trying to mutably borrow two distinct struct fields, this patch removes
the code comment.

[1]
```rust
let x = &mut self.deref_mut().x;
let y = &mut self.deref_mut().y; // error

// ---

let mut this = self.deref_mut();
let x = &mut this.x;
let y = &mut this.y; // ok
```

* Remove redundant nested futures.

* protocols/noise/Cargo: Update to futures preview 0.3.0-alpha.18

* protocols/noise: Improve formatting

* protocols/noise: Return pinned future on authenticated noise upgrade

* protocols/noise: Specify Output of Future embedded in Handshake directly

* *: Ensure Noise handshake futures are Send

* Revert "*: Ensure Noise handshake futures are Send"

This reverts commit 555c2df315e44f21ad39d4408445ce2cb84dd1a4.

* protocols/noise: Ensure NoiseConfig Future is Send

* protocols/noise: Use relative import path for {In,Out}boundUpgrade
This commit is contained in:
Max Inden
2019-10-03 23:40:14 +02:00
committed by Roman Borschel
parent 7f5868472d
commit 73aa27827f
4 changed files with 464 additions and 410 deletions

View File

@ -22,11 +22,11 @@
pub mod handshake; pub mod handshake;
use futures::Poll; use futures::{ready, Poll};
use futures::prelude::*;
use log::{debug, trace}; use log::{debug, trace};
use snow; use snow;
use std::{fmt, io}; use std::{fmt, io, pin::Pin, ops::DerefMut, task::Context};
use tokio_io::{AsyncRead, AsyncWrite};
const MAX_NOISE_PKG_LEN: usize = 65535; const MAX_NOISE_PKG_LEN: usize = 65535;
const MAX_WRITE_BUF_LEN: usize = 16384; const MAX_WRITE_BUF_LEN: usize = 16384;
@ -121,57 +121,75 @@ enum WriteState {
EncErr EncErr
} }
impl<T: io::Read> io::Read for NoiseOutput<T> { impl<T: AsyncRead + Unpin> AsyncRead for NoiseOutput<T> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { fn poll_read(
let buffer = self.buffer.borrow_mut(); mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<usize, std::io::Error>> {
let mut this = self.deref_mut();
let buffer = this.buffer.borrow_mut();
loop { loop {
trace!("read state: {:?}", self.read_state); trace!("read state: {:?}", this.read_state);
match self.read_state { match this.read_state {
ReadState::Init => { ReadState::Init => {
self.read_state = ReadState::ReadLen { buf: [0, 0], off: 0 }; this.read_state = ReadState::ReadLen { buf: [0, 0], off: 0 };
} }
ReadState::ReadLen { mut buf, mut off } => { ReadState::ReadLen { mut buf, mut off } => {
let n = match read_frame_len(&mut self.io, &mut buf, &mut off) { let n = match read_frame_len(&mut this.io, cx, &mut buf, &mut off) {
Ok(Some(n)) => n, Poll::Ready(Ok(Some(n))) => n,
Ok(None) => { Poll::Ready(Ok(None)) => {
trace!("read: eof"); trace!("read: eof");
self.read_state = ReadState::Eof(Ok(())); this.read_state = ReadState::Eof(Ok(()));
return Ok(0) return Poll::Ready(Ok(0))
} }
Err(e) => { Poll::Ready(Err(e)) => {
if e.kind() == io::ErrorKind::WouldBlock { return Poll::Ready(Err(e))
// Preserve read state
self.read_state = ReadState::ReadLen { buf, off };
} }
return Err(e) Poll::Pending => {
this.read_state = ReadState::ReadLen { buf, off };
return Poll::Pending;
} }
}; };
trace!("read: next frame len = {}", n); trace!("read: next frame len = {}", n);
if n == 0 { if n == 0 {
trace!("read: empty frame"); trace!("read: empty frame");
self.read_state = ReadState::Init; this.read_state = ReadState::Init;
continue continue
} }
self.read_state = ReadState::ReadData { len: usize::from(n), off: 0 } this.read_state = ReadState::ReadData { len: usize::from(n), off: 0 }
} }
ReadState::ReadData { len, ref mut off } => { ReadState::ReadData { len, ref mut off } => {
let n = self.io.read(&mut buffer.read[*off .. len])?; let n = match ready!(
Pin::new(&mut this.io).poll_read(cx, &mut buffer.read[*off ..len])
) {
Ok(n) => n,
Err(e) => return Poll::Ready(Err(e)),
};
trace!("read: read {}/{} bytes", *off + n, len); trace!("read: read {}/{} bytes", *off + n, len);
if n == 0 { if n == 0 {
trace!("read: eof"); trace!("read: eof");
self.read_state = ReadState::Eof(Err(())); this.read_state = ReadState::Eof(Err(()));
return Err(io::ErrorKind::UnexpectedEof.into()) return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into()))
} }
*off += n; *off += n;
if len == *off { if len == *off {
trace!("read: decrypting {} bytes", len); trace!("read: decrypting {} bytes", len);
if let Ok(n) = self.session.read_message(&buffer.read[.. len], buffer.read_crypto) { if let Ok(n) = this.session.read_message(
&buffer.read[.. len],
buffer.read_crypto
){
trace!("read: payload len = {} bytes", n); trace!("read: payload len = {} bytes", n);
self.read_state = ReadState::CopyData { len: n, off: 0 } this.read_state = ReadState::CopyData { len: n, off: 0 }
} else { } else {
debug!("decryption error"); debug!("decryption error");
self.read_state = ReadState::DecErr; this.read_state = ReadState::DecErr;
return Err(io::ErrorKind::InvalidData.into()) return Poll::Ready(Err(io::ErrorKind::InvalidData.into()))
} }
} }
} }
@ -181,32 +199,43 @@ impl<T: io::Read> io::Read for NoiseOutput<T> {
trace!("read: copied {}/{} bytes", *off + n, len); trace!("read: copied {}/{} bytes", *off + n, len);
*off += n; *off += n;
if len == *off { if len == *off {
self.read_state = ReadState::ReadLen { buf: [0, 0], off: 0 }; this.read_state = ReadState::ReadLen { buf: [0, 0], off: 0 };
} }
return Ok(n) return Poll::Ready(Ok(n))
} }
ReadState::Eof(Ok(())) => { ReadState::Eof(Ok(())) => {
trace!("read: eof"); trace!("read: eof");
return Ok(0) return Poll::Ready(Ok(0))
} }
ReadState::Eof(Err(())) => { ReadState::Eof(Err(())) => {
trace!("read: eof (unexpected)"); trace!("read: eof (unexpected)");
return Err(io::ErrorKind::UnexpectedEof.into()) return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into()))
} }
ReadState::DecErr => return Err(io::ErrorKind::InvalidData.into()) ReadState::DecErr => return Poll::Ready(Err(io::ErrorKind::InvalidData.into()))
} }
} }
} }
unsafe fn initializer(&self) -> futures::io::Initializer {
futures::io::Initializer::nop()
}
} }
impl<T: io::Write> io::Write for NoiseOutput<T> { impl<T: AsyncWrite + Unpin> AsyncWrite for NoiseOutput<T> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { fn poll_write(
let buffer = self.buffer.borrow_mut(); mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>>{
let mut this = self.deref_mut();
let buffer = this.buffer.borrow_mut();
loop { loop {
trace!("write state: {:?}", self.write_state); trace!("write state: {:?}", this.write_state);
match self.write_state { match this.write_state {
WriteState::Init => { WriteState::Init => {
self.write_state = WriteState::BufferData { off: 0 } this.write_state = WriteState::BufferData { off: 0 }
} }
WriteState::BufferData { ref mut off } => { WriteState::BufferData { ref mut off } => {
let n = std::cmp::min(MAX_WRITE_BUF_LEN - *off, buf.len()); let n = std::cmp::min(MAX_WRITE_BUF_LEN - *off, buf.len());
@ -215,136 +244,157 @@ impl<T: io::Write> io::Write for NoiseOutput<T> {
*off += n; *off += n;
if *off == MAX_WRITE_BUF_LEN { if *off == MAX_WRITE_BUF_LEN {
trace!("write: encrypting {} bytes", *off); trace!("write: encrypting {} bytes", *off);
if let Ok(n) = self.session.write_message(buffer.write, buffer.write_crypto) { match this.session.write_message(buffer.write, buffer.write_crypto) {
Ok(n) => {
trace!("write: cipher text len = {} bytes", n); trace!("write: cipher text len = {} bytes", n);
self.write_state = WriteState::WriteLen { this.write_state = WriteState::WriteLen {
len: n, len: n,
buf: u16::to_be_bytes(n as u16), buf: u16::to_be_bytes(n as u16),
off: 0 off: 0
} }
} else { }
debug!("encryption error"); Err(e) => {
self.write_state = WriteState::EncErr; debug!("encryption error: {:?}", e);
return Err(io::ErrorKind::InvalidData.into()) this.write_state = WriteState::EncErr;
return Poll::Ready(Err(io::ErrorKind::InvalidData.into()))
} }
} }
return Ok(n) }
return Poll::Ready(Ok(n))
} }
WriteState::WriteLen { len, mut buf, mut off } => { WriteState::WriteLen { len, mut buf, mut off } => {
trace!("write: writing len ({}, {:?}, {}/2)", len, buf, off); trace!("write: writing len ({}, {:?}, {}/2)", len, buf, off);
match write_frame_len(&mut self.io, &mut buf, &mut off) { match write_frame_len(&mut this.io, cx, &mut buf, &mut off) {
Err(e) => { Poll::Ready(Ok(true)) => (),
if e.kind() == io::ErrorKind::WouldBlock { Poll::Ready(Ok(false)) => {
self.write_state = WriteState::WriteLen{ len, buf, off };
}
return Err(e)
}
Ok(false) => {
trace!("write: eof"); trace!("write: eof");
self.write_state = WriteState::Eof; this.write_state = WriteState::Eof;
return Err(io::ErrorKind::WriteZero.into()) return Poll::Ready(Err(io::ErrorKind::WriteZero.into()))
} }
Ok(true) => () Poll::Ready(Err(e)) => {
return Poll::Ready(Err(e))
} }
self.write_state = WriteState::WriteData { len, off: 0 } Poll::Pending => {
this.write_state = WriteState::WriteLen{ len, buf, off };
return Poll::Pending
}
}
this.write_state = WriteState::WriteData { len, off: 0 }
} }
WriteState::WriteData { len, ref mut off } => { WriteState::WriteData { len, ref mut off } => {
let n = self.io.write(&buffer.write_crypto[*off .. len])?; let n = match ready!(
Pin::new(&mut this.io).poll_write(cx, &buffer.write_crypto[*off .. len])
) {
Ok(n) => n,
Err(e) => return Poll::Ready(Err(e)),
};
trace!("write: wrote {}/{} bytes", *off + n, len); trace!("write: wrote {}/{} bytes", *off + n, len);
if n == 0 { if n == 0 {
trace!("write: eof"); trace!("write: eof");
self.write_state = WriteState::Eof; this.write_state = WriteState::Eof;
return Err(io::ErrorKind::WriteZero.into()) return Poll::Ready(Err(io::ErrorKind::WriteZero.into()))
} }
*off += n; *off += n;
if len == *off { if len == *off {
trace!("write: finished writing {} bytes", len); trace!("write: finished writing {} bytes", len);
self.write_state = WriteState::Init this.write_state = WriteState::Init
} }
} }
WriteState::Eof => { WriteState::Eof => {
trace!("write: eof"); trace!("write: eof");
return Err(io::ErrorKind::WriteZero.into()) return Poll::Ready(Err(io::ErrorKind::WriteZero.into()))
} }
WriteState::EncErr => return Err(io::ErrorKind::InvalidData.into()) WriteState::EncErr => return Poll::Ready(Err(io::ErrorKind::InvalidData.into()))
} }
} }
} }
fn flush(&mut self) -> io::Result<()> { fn poll_flush(
let buffer = self.buffer.borrow_mut(); mut self: Pin<&mut Self>,
cx: &mut Context<'_>
) -> Poll<Result<(), std::io::Error>> {
let mut this = self.deref_mut();
let buffer = this.buffer.borrow_mut();
loop { loop {
match self.write_state { match this.write_state {
WriteState::Init => return Ok(()), WriteState::Init => return Poll::Ready(Ok(())),
WriteState::BufferData { off } => { WriteState::BufferData { off } => {
trace!("flush: encrypting {} bytes", off); trace!("flush: encrypting {} bytes", off);
if let Ok(n) = self.session.write_message(&buffer.write[.. off], buffer.write_crypto) { match this.session.write_message(&buffer.write[.. off], buffer.write_crypto) {
Ok(n) => {
trace!("flush: cipher text len = {} bytes", n); trace!("flush: cipher text len = {} bytes", n);
self.write_state = WriteState::WriteLen { this.write_state = WriteState::WriteLen {
len: n, len: n,
buf: u16::to_be_bytes(n as u16), buf: u16::to_be_bytes(n as u16),
off: 0 off: 0
} }
} else { }
debug!("encryption error"); Err(e) => {
self.write_state = WriteState::EncErr; debug!("encryption error: {:?}", e);
return Err(io::ErrorKind::InvalidData.into()) this.write_state = WriteState::EncErr;
return Poll::Ready(Err(io::ErrorKind::InvalidData.into()))
}
} }
} }
WriteState::WriteLen { len, mut buf, mut off } => { WriteState::WriteLen { len, mut buf, mut off } => {
trace!("flush: writing len ({}, {:?}, {}/2)", len, buf, off); trace!("flush: writing len ({}, {:?}, {}/2)", len, buf, off);
match write_frame_len(&mut self.io, &mut buf, &mut off) { match write_frame_len(&mut this.io, cx, &mut buf, &mut off) {
Ok(true) => (), Poll::Ready(Ok(true)) => (),
Ok(false) => { Poll::Ready(Ok(false)) => {
trace!("write: eof"); trace!("write: eof");
self.write_state = WriteState::Eof; this.write_state = WriteState::Eof;
return Err(io::ErrorKind::WriteZero.into()) return Poll::Ready(Err(io::ErrorKind::WriteZero.into()))
} }
Err(e) => { Poll::Ready(Err(e)) => {
if e.kind() == io::ErrorKind::WouldBlock { return Poll::Ready(Err(e))
// Preserve write state
self.write_state = WriteState::WriteLen { len, buf, off };
} }
return Err(e) Poll::Pending => {
this.write_state = WriteState::WriteLen { len, buf, off };
return Poll::Pending
} }
} }
self.write_state = WriteState::WriteData { len, off: 0 } this.write_state = WriteState::WriteData { len, off: 0 }
} }
WriteState::WriteData { len, ref mut off } => { WriteState::WriteData { len, ref mut off } => {
let n = self.io.write(&buffer.write_crypto[*off .. len])?; let n = match ready!(
Pin::new(&mut this.io).poll_write(cx, &buffer.write_crypto[*off .. len])
) {
Ok(n) => n,
Err(e) => return Poll::Ready(Err(e)),
};
trace!("flush: wrote {}/{} bytes", *off + n, len); trace!("flush: wrote {}/{} bytes", *off + n, len);
if n == 0 { if n == 0 {
trace!("flush: eof"); trace!("flush: eof");
self.write_state = WriteState::Eof; this.write_state = WriteState::Eof;
return Err(io::ErrorKind::WriteZero.into()) return Poll::Ready(Err(io::ErrorKind::WriteZero.into()))
} }
*off += n; *off += n;
if len == *off { if len == *off {
trace!("flush: finished writing {} bytes", len); trace!("flush: finished writing {} bytes", len);
self.write_state = WriteState::Init; this.write_state = WriteState::Init;
return Ok(()) return Poll::Ready(Ok(()))
} }
} }
WriteState::Eof => { WriteState::Eof => {
trace!("flush: eof"); trace!("flush: eof");
return Err(io::ErrorKind::WriteZero.into()) return Poll::Ready(Err(io::ErrorKind::WriteZero.into()))
} }
WriteState::EncErr => return Err(io::ErrorKind::InvalidData.into()) WriteState::EncErr => return Poll::Ready(Err(io::ErrorKind::InvalidData.into()))
} }
} }
} }
}
impl<T: AsyncRead> AsyncRead for NoiseOutput<T> { fn poll_close(
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { mut self: Pin<&mut Self>,
false cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>>{
Pin::new(&mut self.io).poll_close(cx)
} }
}
impl<T: AsyncWrite> AsyncWrite for NoiseOutput<T> {
fn shutdown(&mut self) -> Poll<(), io::Error> {
self.io.shutdown()
}
} }
/// Read 2 bytes as frame length from the given source into the given buffer. /// Read 2 bytes as frame length from the given source into the given buffer.
@ -356,17 +406,26 @@ impl<T: AsyncWrite> AsyncWrite for NoiseOutput<T> {
/// for the next invocation. /// for the next invocation.
/// ///
/// Returns `None` if EOF has been encountered. /// Returns `None` if EOF has been encountered.
fn read_frame_len<R: io::Read>(io: &mut R, buf: &mut [u8; 2], off: &mut usize) fn read_frame_len<R: AsyncRead + Unpin>(
-> io::Result<Option<u16>> mut io: &mut R,
{ cx: &mut Context<'_>,
buf: &mut [u8; 2],
off: &mut usize,
) -> Poll<Result<Option<u16>, std::io::Error>> {
loop { loop {
let n = io.read(&mut buf[*off ..])?; match ready!(Pin::new(&mut io).poll_read(cx, &mut buf[*off ..])) {
Ok(n) => {
if n == 0 { if n == 0 {
return Ok(None) return Poll::Ready(Ok(None));
} }
*off += n; *off += n;
if *off == 2 { if *off == 2 {
return Ok(Some(u16::from_be_bytes(*buf))) return Poll::Ready(Ok(Some(u16::from_be_bytes(*buf))));
}
},
Err(e) => {
return Poll::Ready(Err(e));
},
} }
} }
} }
@ -380,18 +439,26 @@ fn read_frame_len<R: io::Read>(io: &mut R, buf: &mut [u8; 2], off: &mut usize)
/// be preserved for the next invocation. /// be preserved for the next invocation.
/// ///
/// Returns `false` if EOF has been encountered. /// Returns `false` if EOF has been encountered.
fn write_frame_len<W: io::Write>(io: &mut W, buf: &[u8; 2], off: &mut usize) fn write_frame_len<W: AsyncWrite + Unpin>(
-> io::Result<bool> mut io: &mut W,
{ cx: &mut Context<'_>,
buf: &[u8; 2],
off: &mut usize,
) -> Poll<Result<bool, std::io::Error>> {
loop { loop {
let n = io.write(&buf[*off ..])?; match ready!(Pin::new(&mut io).poll_write(cx, &buf[*off ..])) {
Ok(n) => {
if n == 0 { if n == 0 {
return Ok(false) return Poll::Ready(Ok(false))
} }
*off += n; *off += n;
if *off == 2 { if *off == 2 {
return Ok(true) return Poll::Ready(Ok(true))
}
}
Err(e) => {
return Poll::Ready(Err(e));
}
} }
} }
} }

View File

@ -26,9 +26,10 @@ use crate::error::NoiseError;
use crate::protocol::{Protocol, PublicKey, KeypairIdentity}; use crate::protocol::{Protocol, PublicKey, KeypairIdentity};
use libp2p_core::identity; use libp2p_core::identity;
use futures::prelude::*; use futures::prelude::*;
use std::{mem, io, task::Poll}; use futures::task;
use futures::io::AsyncReadExt;
use protobuf::Message; use protobuf::Message;
use std::{pin::Pin, task::Context};
use super::NoiseOutput; use super::NoiseOutput;
/// The identity of the remote established during a handshake. /// The identity of the remote established during a handshake.
@ -86,129 +87,162 @@ pub enum IdentityExchange {
None { remote: identity::PublicKey } None { remote: identity::PublicKey }
} }
impl<T, C> Handshake<T, C> /// A future performing a Noise handshake pattern.
where pub struct Handshake<T, C>(
T: AsyncRead + AsyncWrite + Send + 'static, Pin<Box<dyn Future<
C: Protocol<C> + AsRef<[u8]> + Send + 'static, Output = Result<(RemoteIdentity<C>, NoiseOutput<T>), NoiseError>,
{ > + Send>>
/// Creates an authenticated Noise handshake for the initiator of a );
/// single roundtrip (2 message) handshake pattern.
/// impl<T, C> Future for Handshake<T, C> {
/// Subject to the chosen [`IdentityExchange`], this message sequence type Output = Result<(RemoteIdentity<C>, NoiseOutput<T>), NoiseError>;
/// identifies the local node to the remote with the first message payload
/// (i.e. unencrypted) and expects the remote to identify itself in the fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> task::Poll<Self::Output> {
/// second message payload. Pin::new(&mut self.0).poll(ctx)
/// }
/// This message sequence is suitable for authenticated 2-message Noise handshake }
/// patterns where the static keys of the initiator and responder are either
/// known (i.e. appear in the pre-message pattern) or are sent with /// Creates an authenticated Noise handshake for the initiator of a
/// the first and second message, respectively (e.g. `IK` or `IX`). /// single roundtrip (2 message) handshake pattern.
/// ///
/// ```raw /// Subject to the chosen [`IdentityExchange`], this message sequence
/// initiator -{id}-> responder /// identifies the local node to the remote with the first message payload
/// initiator <-{id}- responder /// (i.e. unencrypted) and expects the remote to identify itself in the
/// ``` /// second message payload.
pub fn rt1_initiator( ///
/// This message sequence is suitable for authenticated 2-message Noise handshake
/// patterns where the static keys of the initiator and responder are either
/// known (i.e. appear in the pre-message pattern) or are sent with
/// the first and second message, respectively (e.g. `IK` or `IX`).
///
/// ```raw
/// initiator -{id}-> responder
/// initiator <-{id}- responder
/// ```
pub fn rt1_initiator<T, C>(
io: T, io: T,
session: Result<snow::Session, NoiseError>, session: Result<snow::Session, NoiseError>,
identity: KeypairIdentity, identity: KeypairIdentity,
identity_x: IdentityExchange identity_x: IdentityExchange
) -> Result<(RemoteIdentity<C>, NoiseOutput<T>), NoiseError> { ) -> Handshake<T, C>
let mut state = State::new(io, session, identity, identity_x); where
T: AsyncWrite + AsyncRead + Send + Unpin + 'static,
C: Protocol<C> + AsRef<[u8]>
{
Handshake(Box::pin(async move {
let mut state = State::new(io, session, identity, identity_x)?;
send_identity(&mut state).await?; send_identity(&mut state).await?;
recv_identity(&mut state).await?; recv_identity(&mut state).await?;
state.finish.await state.finish()
} }))
}
/// Creates an authenticated Noise handshake for the responder of a /// Creates an authenticated Noise handshake for the responder of a
/// single roundtrip (2 message) handshake pattern. /// single roundtrip (2 message) handshake pattern.
/// ///
/// Subject to the chosen [`IdentityExchange`], this message sequence expects the /// Subject to the chosen [`IdentityExchange`], this message sequence expects the
/// remote to identify itself in the first message payload (i.e. unencrypted) /// remote to identify itself in the first message payload (i.e. unencrypted)
/// and identifies the local node to the remote in the second message payload. /// and identifies the local node to the remote in the second message payload.
/// ///
/// This message sequence is suitable for authenticated 2-message Noise handshake /// This message sequence is suitable for authenticated 2-message Noise handshake
/// patterns where the static keys of the initiator and responder are either /// patterns where the static keys of the initiator and responder are either
/// known (i.e. appear in the pre-message pattern) or are sent with the first /// known (i.e. appear in the pre-message pattern) or are sent with the first
/// and second message, respectively (e.g. `IK` or `IX`). /// and second message, respectively (e.g. `IK` or `IX`).
/// ///
/// ```raw /// ```raw
/// initiator -{id}-> responder /// initiator -{id}-> responder
/// initiator <-{id}- responder /// initiator <-{id}- responder
/// ``` /// ```
pub fn rt1_responder( pub fn rt1_responder<T, C>(
io: T, io: T,
session: Result<snow::Session, NoiseError>, session: Result<snow::Session, NoiseError>,
identity: KeypairIdentity, identity: KeypairIdentity,
identity_x: IdentityExchange, identity_x: IdentityExchange,
) -> Result<(RemoteIdentity<C>, NoiseOutput<T>), NoiseError> { ) -> Handshake<T, C>
let mut state = State::new(io, session, identity, identity_x); where
T: AsyncWrite + AsyncRead + Send + Unpin + 'static,
C: Protocol<C> + AsRef<[u8]>
{
Handshake(Box::pin(async move {
let mut state = State::new(io, session, identity, identity_x)?;
recv_identity(&mut state).await?; recv_identity(&mut state).await?;
send_identity(&mut state).await?; send_identity(&mut state).await?;
state.finish.await state.finish()
} }))
}
/// Creates an authenticated Noise handshake for the initiator of a /// Creates an authenticated Noise handshake for the initiator of a
/// 1.5-roundtrip (3 message) handshake pattern. /// 1.5-roundtrip (3 message) handshake pattern.
/// ///
/// Subject to the chosen [`IdentityExchange`], this message sequence expects /// Subject to the chosen [`IdentityExchange`], this message sequence expects
/// the remote to identify itself in the second message payload and /// the remote to identify itself in the second message payload and
/// identifies the local node to the remote in the third message payload. /// identifies the local node to the remote in the third message payload.
/// The first (unencrypted) message payload is always empty. /// The first (unencrypted) message payload is always empty.
/// ///
/// This message sequence is suitable for authenticated 3-message Noise handshake /// This message sequence is suitable for authenticated 3-message Noise handshake
/// patterns where the static keys of the responder and initiator are either known /// patterns where the static keys of the responder and initiator are either known
/// (i.e. appear in the pre-message pattern) or are sent with the second and third /// (i.e. appear in the pre-message pattern) or are sent with the second and third
/// message, respectively (e.g. `XX`). /// message, respectively (e.g. `XX`).
/// ///
/// ```raw /// ```raw
/// initiator --{}--> responder /// initiator --{}--> responder
/// initiator <-{id}- responder /// initiator <-{id}- responder
/// initiator -{id}-> responder /// initiator -{id}-> responder
/// ``` /// ```
pub fn rt15_initiator( pub fn rt15_initiator<T, C>(
io: T, io: T,
session: Result<snow::Session, NoiseError>, session: Result<snow::Session, NoiseError>,
identity: KeypairIdentity, identity: KeypairIdentity,
identity_x: IdentityExchange identity_x: IdentityExchange
) -> Result<(RemoteIdentity<C>, NoiseOutput<T>), NoiseError> { ) -> Handshake<T, C>
let mut state = State::new(io, session, identity, identity_x); where
T: AsyncWrite + AsyncRead + Unpin + Send + 'static,
C: Protocol<C> + AsRef<[u8]>
{
Handshake(Box::pin(async move {
let mut state = State::new(io, session, identity, identity_x)?;
send_empty(&mut state).await?; send_empty(&mut state).await?;
send_identity(&mut state).await?;
recv_identity(&mut state).await?; recv_identity(&mut state).await?;
state.finish.await send_identity(&mut state).await?;
} state.finish()
}))
}
/// Creates an authenticated Noise handshake for the responder of a /// Creates an authenticated Noise handshake for the responder of a
/// 1.5-roundtrip (3 message) handshake pattern. /// 1.5-roundtrip (3 message) handshake pattern.
/// ///
/// Subject to the chosen [`IdentityExchange`], this message sequence /// Subject to the chosen [`IdentityExchange`], this message sequence
/// identifies the local node in the second message payload and expects /// identifies the local node in the second message payload and expects
/// the remote to identify itself in the third message payload. The first /// the remote to identify itself in the third message payload. The first
/// (unencrypted) message payload is always empty. /// (unencrypted) message payload is always empty.
/// ///
/// This message sequence is suitable for authenticated 3-message Noise handshake /// This message sequence is suitable for authenticated 3-message Noise handshake
/// patterns where the static keys of the responder and initiator are either known /// patterns where the static keys of the responder and initiator are either known
/// (i.e. appear in the pre-message pattern) or are sent with the second and third /// (i.e. appear in the pre-message pattern) or are sent with the second and third
/// message, respectively (e.g. `XX`). /// message, respectively (e.g. `XX`).
/// ///
/// ```raw /// ```raw
/// initiator --{}--> responder /// initiator --{}--> responder
/// initiator <-{id}- responder /// initiator <-{id}- responder
/// initiator -{id}-> responder /// initiator -{id}-> responder
/// ``` /// ```
pub async fn rt15_responder( pub fn rt15_responder<T, C>(
io: T, io: T,
session: Result<snow::Session, NoiseError>, session: Result<snow::Session, NoiseError>,
identity: KeypairIdentity, identity: KeypairIdentity,
identity_x: IdentityExchange identity_x: IdentityExchange
) -> Result<(RemoteIdentity<C>, NoiseOutput<T>), NoiseError> { ) -> Handshake<T, C>
let mut state = State::new(io, session, identity, identity_x); where
T: AsyncWrite + AsyncRead + Unpin + Send + 'static,
C: Protocol<C> + AsRef<[u8]>
{
Handshake(Box::pin(async move {
let mut state = State::new(io, session, identity, identity_x)?;
recv_empty(&mut state).await?; recv_empty(&mut state).await?;
send_identity(&mut state).await?; send_identity(&mut state).await?;
recv_identity(&mut state).await?; recv_identity(&mut state).await?;
state.finish().await state.finish()
} }))
} }
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
@ -240,14 +274,14 @@ impl<T> State<T> {
session: Result<snow::Session, NoiseError>, session: Result<snow::Session, NoiseError>,
identity: KeypairIdentity, identity: KeypairIdentity,
identity_x: IdentityExchange identity_x: IdentityExchange
) -> FutureResult<Self, NoiseError> { ) -> Result<Self, NoiseError> {
let (id_remote_pubkey, send_identity) = match identity_x { let (id_remote_pubkey, send_identity) = match identity_x {
IdentityExchange::Mutual => (None, true), IdentityExchange::Mutual => (None, true),
IdentityExchange::Send { remote } => (Some(remote), true), IdentityExchange::Send { remote } => (Some(remote), true),
IdentityExchange::Receive => (None, false), IdentityExchange::Receive => (None, false),
IdentityExchange::None { remote } => (Some(remote), false) IdentityExchange::None { remote } => (Some(remote), false)
}; };
future::result(session.map(|s| session.map(|s|
State { State {
identity, identity,
io: NoiseOutput::new(io, s), io: NoiseOutput::new(io, s),
@ -255,7 +289,7 @@ impl<T> State<T> {
id_remote_pubkey, id_remote_pubkey,
send_identity send_identity
} }
)) )
} }
} }
@ -263,19 +297,19 @@ impl<T> State<T>
{ {
/// Finish a handshake, yielding the established remote identity and the /// Finish a handshake, yielding the established remote identity and the
/// [`NoiseOutput`] for communicating on the encrypted channel. /// [`NoiseOutput`] for communicating on the encrypted channel.
fn finish<C>(self) -> FutureResult<(RemoteIdentity<C>, NoiseOutput<T>), NoiseError> fn finish<C>(self) -> Result<(RemoteIdentity<C>, NoiseOutput<T>), NoiseError>
where where
C: Protocol<C> + AsRef<[u8]> C: Protocol<C> + AsRef<[u8]>
{ {
let dh_remote_pubkey = match self.io.session.get_remote_static() { let dh_remote_pubkey = match self.io.session.get_remote_static() {
None => None, None => None,
Some(k) => match C::public_from_bytes(k) { Some(k) => match C::public_from_bytes(k) {
Err(e) => return future::err(e), Err(e) => return Err(e),
Ok(dh_pk) => Some(dh_pk) Ok(dh_pk) => Some(dh_pk)
} }
}; };
match self.io.session.into_transport_mode() { match self.io.session.into_transport_mode() {
Err(e) => future::err(e.into()), Err(e) => Err(e.into()),
Ok(s) => { Ok(s) => {
let remote = match (self.id_remote_pubkey, dh_remote_pubkey) { let remote = match (self.id_remote_pubkey, dh_remote_pubkey) {
(_, None) => RemoteIdentity::Unknown, (_, None) => RemoteIdentity::Unknown,
@ -284,11 +318,11 @@ impl<T> State<T>
if C::verify(&id_pk, &dh_pk, &self.dh_remote_pubkey_sig) { if C::verify(&id_pk, &dh_pk, &self.dh_remote_pubkey_sig) {
RemoteIdentity::IdentityKey(id_pk) RemoteIdentity::IdentityKey(id_pk)
} else { } else {
return future::err(NoiseError::InvalidKey) return Err(NoiseError::InvalidKey)
} }
} }
}; };
future::ok((remote, NoiseOutput { session: s, .. self.io })) Ok((remote, NoiseOutput { session: s, .. self.io }))
} }
} }
} }
@ -297,121 +331,72 @@ impl<T> State<T>
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
// Handshake Message Futures // Handshake Message Futures
// RecvEmpty -----------------------------------------------------------------
/// A future for receiving a Noise handshake message with an empty payload. /// A future for receiving a Noise handshake message with an empty payload.
///
/// Obtained from [`Handshake::recv_empty`].
async fn recv_empty<T>(state: &mut State<T>) -> Result<(), NoiseError> async fn recv_empty<T>(state: &mut State<T>) -> Result<(), NoiseError>
where where
T: AsyncRead T: AsyncRead + Unpin
{ {
state.io.read(&mut []).await?; state.io.read(&mut []).await?;
Ok(()) Ok(())
} }
// SendEmpty -----------------------------------------------------------------
/// A future for sending a Noise handshake message with an empty payload. /// A future for sending a Noise handshake message with an empty payload.
///
/// Obtained from [`Handshake::send_empty`].
async fn send_empty<T>(state: &mut State<T>) -> Result<(), NoiseError> async fn send_empty<T>(state: &mut State<T>) -> Result<(), NoiseError>
where where
T: AsyncWrite T: AsyncWrite + Unpin
{ {
state.write(&[]).await?; state.io.write(&[]).await?;
state.flush().await?; state.io.flush().await?;
Ok(()) Ok(())
} }
// RecvIdentity --------------------------------------------------------------
/// A future for receiving a Noise handshake message with a payload /// A future for receiving a Noise handshake message with a payload
/// identifying the remote. /// identifying the remote.
/// async fn recv_identity<T>(state: &mut State<T>) -> Result<(), NoiseError>
/// Obtained from [`Handshake::recv_identity`].
struct RecvIdentity<T> {
state: RecvIdentityState<T>
}
enum RecvIdentityState<T> {
Init(State<T>),
ReadPayloadLen(nio::ReadExact<State<T>, [u8; 2]>),
ReadPayload(nio::ReadExact<State<T>, Vec<u8>>),
Done
}
impl<T> Future for RecvIdentity<T>
where where
T: AsyncRead, T: AsyncRead + Unpin,
{ {
type Error = NoiseError; let mut len_buf = [0,0];
type Item = State<T>; state.io.read_exact(&mut len_buf).await?;
let len = u16::from_be_bytes(len_buf) as usize;
let mut payload_buf = vec![0; len];
state.io.read_exact(&mut payload_buf).await?;
let pb: payload::Identity = protobuf::parse_from_bytes(&payload_buf)?;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop {
match mem::replace(&mut self.state, RecvIdentityState::Done) {
RecvIdentityState::Init(st) => {
self.state = RecvIdentityState::ReadPayloadLen(nio::read_exact(st, [0, 0]));
},
RecvIdentityState::ReadPayloadLen(mut read_len) => {
if let Async::Ready((st, bytes)) = read_len.poll()? {
let len = u16::from_be_bytes(bytes) as usize;
let buf = vec![0; len];
self.state = RecvIdentityState::ReadPayload(nio::read_exact(st, buf));
} else {
self.state = RecvIdentityState::ReadPayloadLen(read_len);
return Ok(Async::NotReady);
}
},
RecvIdentityState::ReadPayload(mut read_payload) => {
if let Async::Ready((mut st, bytes)) = read_payload.poll()? {
let pb: payload::Identity = protobuf::parse_from_bytes(&bytes)?;
if !pb.pubkey.is_empty() { if !pb.pubkey.is_empty() {
let pk = identity::PublicKey::from_protobuf_encoding(pb.get_pubkey()) let pk = identity::PublicKey::from_protobuf_encoding(pb.get_pubkey())
.map_err(|_| NoiseError::InvalidKey)?; .map_err(|_| NoiseError::InvalidKey)?;
if let Some(ref k) = st.id_remote_pubkey { if let Some(ref k) = state.id_remote_pubkey {
if k != &pk { if k != &pk {
return Err(NoiseError::InvalidKey) return Err(NoiseError::InvalidKey)
} }
} }
st.id_remote_pubkey = Some(pk); state.id_remote_pubkey = Some(pk);
} }
if !pb.signature.is_empty() { if !pb.signature.is_empty() {
st.dh_remote_pubkey_sig = Some(pb.signature) state.dh_remote_pubkey_sig = Some(pb.signature);
}
return Ok(Async::Ready(st))
} else {
self.state = RecvIdentityState::ReadPayload(read_payload);
return Ok(Async::NotReady)
}
},
RecvIdentityState::Done => panic!("RecvIdentity polled after completion")
}
}
} }
Ok(())
} }
// SendIdentity --------------------------------------------------------------
/// Send a Noise handshake message with a payload identifying the local node to the remote. /// Send a Noise handshake message with a payload identifying the local node to the remote.
///
/// Obtained from [`Handshake::send_identity`].
async fn send_identity<T>(state: &mut State<T>) -> Result<(), NoiseError> async fn send_identity<T>(state: &mut State<T>) -> Result<(), NoiseError>
where where
T: AsyncWrite T: AsyncWrite + Unpin,
{ {
let mut pb = payload::Identity::new(); let mut pb = payload::Identity::new();
if st.send_identity { if state.send_identity {
pb.set_pubkey(st.identity.public.clone().into_protobuf_encoding()); pb.set_pubkey(state.identity.public.clone().into_protobuf_encoding());
} }
if let Some(ref sig) = st.identity.signature { if let Some(ref sig) = state.identity.signature {
pb.set_signature(sig.clone()); pb.set_signature(sig.clone());
} }
let pb_bytes = pb.write_to_bytes()?; let pb_bytes = pb.write_to_bytes()?;
let len = (pb_bytes.len() as u16).to_be_bytes(); let len = (pb_bytes.len() as u16).to_be_bytes();
st.write_all(&len).await?; state.io.write_all(&len).await?;
st.write_all(&pb_bytes).await?; state.io.write_all(&pb_bytes).await?;
st.flush().await?; state.io.flush().await?;
Ok(()) Ok(())
} }

View File

@ -25,11 +25,11 @@
//! //!
//! This crate provides `libp2p_core::InboundUpgrade` and `libp2p_core::OutboundUpgrade` //! This crate provides `libp2p_core::InboundUpgrade` and `libp2p_core::OutboundUpgrade`
//! implementations for various noise handshake patterns (currently `IK`, `IX`, and `XX`) //! implementations for various noise handshake patterns (currently `IK`, `IX`, and `XX`)
//! over a particular choice of DH key agreement (currently only X25519). //! over a particular choice of DiffieHellman key agreement (currently only X25519).
//! //!
//! All upgrades produce as output a pair, consisting of the remote's static public key //! All upgrades produce as output a pair, consisting of the remote's static public key
//! and a `NoiseOutput` which represents the established cryptographic session with the //! and a `NoiseOutput` which represents the established cryptographic session with the
//! remote, implementing `tokio_io::AsyncRead` and `tokio_io::AsyncWrite`. //! remote, implementing `futures::io::AsyncRead` and `futures::io::AsyncWrite`.
//! //!
//! # Usage //! # Usage
//! //!
@ -57,12 +57,14 @@ mod protocol;
pub use error::NoiseError; pub use error::NoiseError;
pub use io::NoiseOutput; pub use io::NoiseOutput;
pub use io::handshake::{RemoteIdentity, IdentityExchange}; pub use io::handshake;
pub use io::handshake::{Handshake, RemoteIdentity, IdentityExchange};
pub use protocol::{Keypair, AuthenticKeypair, KeypairIdentity, PublicKey, SecretKey}; pub use protocol::{Keypair, AuthenticKeypair, KeypairIdentity, PublicKey, SecretKey};
pub use protocol::{Protocol, ProtocolParams, x25519::X25519, IX, IK, XX}; pub use protocol::{Protocol, ProtocolParams, x25519::X25519, IX, IK, XX};
use futures::prelude::*;
use libp2p_core::{identity, PeerId, UpgradeInfo, InboundUpgrade, OutboundUpgrade, Negotiated}; use libp2p_core::{identity, PeerId, UpgradeInfo, InboundUpgrade, OutboundUpgrade, Negotiated};
use tokio_io::{AsyncRead, AsyncWrite}; use std::pin::Pin;
use zeroize::Zeroize; use zeroize::Zeroize;
/// The protocol upgrade configuration. /// The protocol upgrade configuration.
@ -157,7 +159,7 @@ where
impl<T, C> InboundUpgrade<T> for NoiseConfig<IX, C> impl<T, C> InboundUpgrade<T> for NoiseConfig<IX, C>
where where
NoiseConfig<IX, C>: UpgradeInfo, NoiseConfig<IX, C>: UpgradeInfo,
T: AsyncRead + AsyncWrite + Send + 'static, T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
C: Protocol<C> + AsRef<[u8]> + Zeroize + Send + 'static, C: Protocol<C> + AsRef<[u8]> + Zeroize + Send + 'static,
{ {
type Output = (RemoteIdentity<C>, NoiseOutput<Negotiated<T>>); type Output = (RemoteIdentity<C>, NoiseOutput<Negotiated<T>>);
@ -169,7 +171,7 @@ where
.local_private_key(self.dh_keys.secret().as_ref()) .local_private_key(self.dh_keys.secret().as_ref())
.build_responder() .build_responder()
.map_err(NoiseError::from); .map_err(NoiseError::from);
Handshake::rt1_responder(socket, session, handshake::rt1_responder(socket, session,
self.dh_keys.into_identity(), self.dh_keys.into_identity(),
IdentityExchange::Mutual) IdentityExchange::Mutual)
} }
@ -178,7 +180,7 @@ where
impl<T, C> OutboundUpgrade<T> for NoiseConfig<IX, C> impl<T, C> OutboundUpgrade<T> for NoiseConfig<IX, C>
where where
NoiseConfig<IX, C>: UpgradeInfo, NoiseConfig<IX, C>: UpgradeInfo,
T: AsyncRead + AsyncWrite + Send + 'static, T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
C: Protocol<C> + AsRef<[u8]> + Zeroize + Send + 'static, C: Protocol<C> + AsRef<[u8]> + Zeroize + Send + 'static,
{ {
type Output = (RemoteIdentity<C>, NoiseOutput<Negotiated<T>>); type Output = (RemoteIdentity<C>, NoiseOutput<Negotiated<T>>);
@ -190,7 +192,7 @@ where
.local_private_key(self.dh_keys.secret().as_ref()) .local_private_key(self.dh_keys.secret().as_ref())
.build_initiator() .build_initiator()
.map_err(NoiseError::from); .map_err(NoiseError::from);
Handshake::rt1_initiator(socket, session, handshake::rt1_initiator(socket, session,
self.dh_keys.into_identity(), self.dh_keys.into_identity(),
IdentityExchange::Mutual) IdentityExchange::Mutual)
} }
@ -201,7 +203,7 @@ where
impl<T, C> InboundUpgrade<T> for NoiseConfig<XX, C> impl<T, C> InboundUpgrade<T> for NoiseConfig<XX, C>
where where
NoiseConfig<XX, C>: UpgradeInfo, NoiseConfig<XX, C>: UpgradeInfo,
T: AsyncRead + AsyncWrite + Send + 'static, T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
C: Protocol<C> + AsRef<[u8]> + Zeroize + Send + 'static, C: Protocol<C> + AsRef<[u8]> + Zeroize + Send + 'static,
{ {
type Output = (RemoteIdentity<C>, NoiseOutput<Negotiated<T>>); type Output = (RemoteIdentity<C>, NoiseOutput<Negotiated<T>>);
@ -213,7 +215,7 @@ where
.local_private_key(self.dh_keys.secret().as_ref()) .local_private_key(self.dh_keys.secret().as_ref())
.build_responder() .build_responder()
.map_err(NoiseError::from); .map_err(NoiseError::from);
Handshake::rt15_responder(socket, session, handshake::rt15_responder(socket, session,
self.dh_keys.into_identity(), self.dh_keys.into_identity(),
IdentityExchange::Mutual) IdentityExchange::Mutual)
} }
@ -222,7 +224,7 @@ where
impl<T, C> OutboundUpgrade<T> for NoiseConfig<XX, C> impl<T, C> OutboundUpgrade<T> for NoiseConfig<XX, C>
where where
NoiseConfig<XX, C>: UpgradeInfo, NoiseConfig<XX, C>: UpgradeInfo,
T: AsyncRead + AsyncWrite + Send + 'static, T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
C: Protocol<C> + AsRef<[u8]> + Zeroize + Send + 'static, C: Protocol<C> + AsRef<[u8]> + Zeroize + Send + 'static,
{ {
type Output = (RemoteIdentity<C>, NoiseOutput<Negotiated<T>>); type Output = (RemoteIdentity<C>, NoiseOutput<Negotiated<T>>);
@ -234,7 +236,7 @@ where
.local_private_key(self.dh_keys.secret().as_ref()) .local_private_key(self.dh_keys.secret().as_ref())
.build_initiator() .build_initiator()
.map_err(NoiseError::from); .map_err(NoiseError::from);
Handshake::rt15_initiator(socket, session, handshake::rt15_initiator(socket, session,
self.dh_keys.into_identity(), self.dh_keys.into_identity(),
IdentityExchange::Mutual) IdentityExchange::Mutual)
} }
@ -245,7 +247,7 @@ where
impl<T, C> InboundUpgrade<T> for NoiseConfig<IK, C> impl<T, C> InboundUpgrade<T> for NoiseConfig<IK, C>
where where
NoiseConfig<IK, C>: UpgradeInfo, NoiseConfig<IK, C>: UpgradeInfo,
T: AsyncRead + AsyncWrite + Send + 'static, T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
C: Protocol<C> + AsRef<[u8]> + Zeroize + Send + 'static, C: Protocol<C> + AsRef<[u8]> + Zeroize + Send + 'static,
{ {
type Output = (RemoteIdentity<C>, NoiseOutput<Negotiated<T>>); type Output = (RemoteIdentity<C>, NoiseOutput<Negotiated<T>>);
@ -257,7 +259,7 @@ where
.local_private_key(self.dh_keys.secret().as_ref()) .local_private_key(self.dh_keys.secret().as_ref())
.build_responder() .build_responder()
.map_err(NoiseError::from); .map_err(NoiseError::from);
Handshake::rt1_responder(socket, session, handshake::rt1_responder(socket, session,
self.dh_keys.into_identity(), self.dh_keys.into_identity(),
IdentityExchange::Receive) IdentityExchange::Receive)
} }
@ -266,7 +268,7 @@ where
impl<T, C> OutboundUpgrade<T> for NoiseConfig<IK, C, (PublicKey<C>, identity::PublicKey)> impl<T, C> OutboundUpgrade<T> for NoiseConfig<IK, C, (PublicKey<C>, identity::PublicKey)>
where where
NoiseConfig<IK, C, (PublicKey<C>, identity::PublicKey)>: UpgradeInfo, NoiseConfig<IK, C, (PublicKey<C>, identity::PublicKey)>: UpgradeInfo,
T: AsyncRead + AsyncWrite + Send + 'static, T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
C: Protocol<C> + AsRef<[u8]> + Zeroize + Send + 'static, C: Protocol<C> + AsRef<[u8]> + Zeroize + Send + 'static,
{ {
type Output = (RemoteIdentity<C>, NoiseOutput<Negotiated<T>>); type Output = (RemoteIdentity<C>, NoiseOutput<Negotiated<T>>);
@ -279,7 +281,7 @@ where
.remote_public_key(self.remote.0.as_ref()) .remote_public_key(self.remote.0.as_ref())
.build_initiator() .build_initiator()
.map_err(NoiseError::from); .map_err(NoiseError::from);
Handshake::rt1_initiator(socket, session, handshake::rt1_initiator(socket, session,
self.dh_keys.into_identity(), self.dh_keys.into_identity(),
IdentityExchange::Send { remote: self.remote.1 }) IdentityExchange::Send { remote: self.remote.1 })
} }
@ -319,23 +321,20 @@ where
NoiseConfig<P, C, R>: UpgradeInfo + InboundUpgrade<T, NoiseConfig<P, C, R>: UpgradeInfo + InboundUpgrade<T,
Output = (RemoteIdentity<C>, NoiseOutput<Negotiated<T>>), Output = (RemoteIdentity<C>, NoiseOutput<Negotiated<T>>),
Error = NoiseError Error = NoiseError
>, > + 'static,
<NoiseConfig<P, C, R> as InboundUpgrade<T>>::Future: Send,
T: AsyncRead + AsyncWrite + Send + 'static, T: AsyncRead + AsyncWrite + Send + 'static,
C: Protocol<C> + AsRef<[u8]> + Zeroize + Send + 'static, C: Protocol<C> + AsRef<[u8]> + Zeroize + Send + 'static,
{ {
type Output = (PeerId, NoiseOutput<Negotiated<T>>); type Output = (PeerId, NoiseOutput<Negotiated<T>>);
type Error = NoiseError; type Error = NoiseError;
type Future = future::AndThen< type Future = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
<NoiseConfig<P, C, R> as InboundUpgrade<T>>::Future,
FutureResult<Self::Output, Self::Error>,
fn((RemoteIdentity<C>, NoiseOutput<Negotiated<T>>)) -> FutureResult<Self::Output, Self::Error>
>;
fn upgrade_inbound(self, socket: Negotiated<T>, info: Self::Info) -> Self::Future { fn upgrade_inbound(self, socket: Negotiated<T>, info: Self::Info) -> Self::Future {
self.config.upgrade_inbound(socket, info) Box::pin(self.config.upgrade_inbound(socket, info)
.and_then(|(remote, io)| future::result(match remote { .and_then(|(remote, io)| match remote {
RemoteIdentity::IdentityKey(pk) => Ok((pk.into_peer_id(), io)), RemoteIdentity::IdentityKey(pk) => future::ok((pk.into_peer_id(), io)),
_ => Err(NoiseError::AuthenticationFailed) _ => future::err(NoiseError::AuthenticationFailed)
})) }))
} }
} }
@ -345,24 +344,20 @@ where
NoiseConfig<P, C, R>: UpgradeInfo + OutboundUpgrade<T, NoiseConfig<P, C, R>: UpgradeInfo + OutboundUpgrade<T,
Output = (RemoteIdentity<C>, NoiseOutput<Negotiated<T>>), Output = (RemoteIdentity<C>, NoiseOutput<Negotiated<T>>),
Error = NoiseError Error = NoiseError
>, > + 'static,
<NoiseConfig<P, C, R> as OutboundUpgrade<T>>::Future: Send,
T: AsyncRead + AsyncWrite + Send + 'static, T: AsyncRead + AsyncWrite + Send + 'static,
C: Protocol<C> + AsRef<[u8]> + Zeroize + Send + 'static, C: Protocol<C> + AsRef<[u8]> + Zeroize + Send + 'static,
{ {
type Output = (PeerId, NoiseOutput<Negotiated<T>>); type Output = (PeerId, NoiseOutput<Negotiated<T>>);
type Error = NoiseError; type Error = NoiseError;
type Future = future::AndThen< type Future = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
<NoiseConfig<P, C, R> as OutboundUpgrade<T>>::Future,
FutureResult<Self::Output, Self::Error>,
fn((RemoteIdentity<C>, NoiseOutput<Negotiated<T>>)) -> FutureResult<Self::Output, Self::Error>
>;
fn upgrade_outbound(self, socket: Negotiated<T>, info: Self::Info) -> Self::Future { fn upgrade_outbound(self, socket: Negotiated<T>, info: Self::Info) -> Self::Future {
self.config.upgrade_outbound(socket, info) Box::pin(self.config.upgrade_outbound(socket, info)
.and_then(|(remote, io)| future::result(match remote { .and_then(|(remote, io)| match remote {
RemoteIdentity::IdentityKey(pk) => Ok((pk.into_peer_id(), io)), RemoteIdentity::IdentityKey(pk) => future::ok((pk.into_peer_id(), io)),
_ => Err(NoiseError::AuthenticationFailed) _ => future::err(NoiseError::AuthenticationFailed)
})) }))
} }
} }

View File

@ -26,7 +26,6 @@ use libp2p_noise::{Keypair, X25519, NoiseConfig, RemoteIdentity, NoiseError, Noi
use libp2p_tcp::{TcpConfig, TcpTransStream}; use libp2p_tcp::{TcpConfig, TcpTransStream};
use log::info; use log::info;
use quickcheck::QuickCheck; use quickcheck::QuickCheck;
use tokio::{self, io};
#[allow(dead_code)] #[allow(dead_code)]
fn core_upgrade_compat() { fn core_upgrade_compat() {
@ -113,9 +112,9 @@ fn ik_xx() {
let server_transport = TcpConfig::new() let server_transport = TcpConfig::new()
.and_then(move |output, endpoint| { .and_then(move |output, endpoint| {
if endpoint.is_listener() { if endpoint.is_listener() {
Either::A(apply_inbound(output, NoiseConfig::ik_listener(server_dh))) Either::Left(apply_inbound(output, NoiseConfig::ik_listener(server_dh)))
} else { } else {
Either::B(apply_outbound(output, NoiseConfig::xx(server_dh))) Either::Right(apply_outbound(output, NoiseConfig::xx(server_dh)))
} }
}) })
.and_then(move |out, _| expect_identity(out, &client_id_public)); .and_then(move |out, _| expect_identity(out, &client_id_public));
@ -125,10 +124,10 @@ fn ik_xx() {
let client_transport = TcpConfig::new() let client_transport = TcpConfig::new()
.and_then(move |output, endpoint| { .and_then(move |output, endpoint| {
if endpoint.is_dialer() { if endpoint.is_dialer() {
Either::A(apply_outbound(output, Either::Left(apply_outbound(output,
NoiseConfig::ik_dialer(client_dh, server_id_public, server_dh_public))) NoiseConfig::ik_dialer(client_dh, server_id_public, server_dh_public)))
} else { } else {
Either::B(apply_inbound(output, NoiseConfig::xx(client_dh))) Either::Right(apply_inbound(output, NoiseConfig::xx(client_dh)))
} }
}) })
.and_then(move |out, _| expect_identity(out, &server_id_public2)); .and_then(move |out, _| expect_identity(out, &server_id_public2));
@ -145,55 +144,63 @@ fn run<T, U>(server_transport: T, client_transport: U, message1: Vec<u8>)
where where
T: Transport<Output = Output>, T: Transport<Output = Output>,
T::Dial: Send + 'static, T::Dial: Send + 'static,
T::Listener: Send + 'static, T::Listener: Send + Unpin + futures::stream::TryStream + 'static,
T::ListenerUpgrade: Send + 'static, T::ListenerUpgrade: Send + 'static,
U: Transport<Output = Output>, U: Transport<Output = Output>,
U::Dial: Send + 'static, U::Dial: Send + 'static,
U::Listener: Send + 'static, U::Listener: Send + 'static,
U::ListenerUpgrade: Send + 'static, U::ListenerUpgrade: Send + 'static,
{ {
let message2 = message1.clone(); futures::executor::block_on(async {
let mut message2 = message1.clone();
let mut server = server_transport let mut server: T::Listener = server_transport
.listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap())
.unwrap(); .unwrap();
let server_address = server.by_ref().wait() let server_address = server.try_next()
.next() .await
.expect("some event") .expect("some event")
.expect("no error") .expect("no error")
.into_new_address() .into_new_address()
.expect("listen address"); .expect("listen address");
let server = server.take(1) let client_fut = async {
.filter_map(ListenerEvent::into_upgrade) let mut client_session = client_transport.dial(server_address.clone())
.and_then(|client| client.0) .unwrap()
.map_err(|e| panic!("server error: {}", e)) .await
.and_then(|(_, client)| { .map(|(_, session)| session)
.expect("no error");
client_session.write_all(&mut message2).await.expect("no error");
client_session.flush().await.expect("no error");
};
let server_fut = async {
let mut server_session = server.try_next()
.await
.expect("some event")
.map(ListenerEvent::into_upgrade)
.expect("no error")
.map(|client| client.0)
.expect("listener upgrade")
.await
.map(|(_, session)| session)
.expect("no error");
let mut server_buffer = vec![];
info!("server: reading message"); info!("server: reading message");
io::read_to_end(client, Vec::new()) server_session.read_to_end(&mut server_buffer).await.expect("no error");
assert_eq!(server_buffer, message1);
};
futures::future::join(server_fut, client_fut).await;
}) })
.for_each(move |msg| {
assert_eq!(msg.1, message1);
Ok(())
});
let client = client_transport.dial(server_address.clone()).unwrap()
.map_err(|e| panic!("client error: {}", e))
.and_then(move |(_, server)| {
io::write_all(server, message2).and_then(|(client, _)| io::flush(client))
})
.map(|_| ());
let future = client.join(server)
.map_err(|e| panic!("{:?}", e))
.map(|_| ());
tokio::run(future)
} }
fn expect_identity(output: Output, pk: &identity::PublicKey) fn expect_identity(output: Output, pk: &identity::PublicKey)
-> impl Future<Item = Output, Error = NoiseError> -> impl Future<Output = Result<Output, NoiseError>>
{ {
match output.0 { match output.0 {
RemoteIdentity::IdentityKey(ref k) if k == pk => future::ok(output), RemoteIdentity::IdentityKey(ref k) if k == pk => future::ok(output),