// Copyright 2019 Parity Technologies (UK) Ltd. // // Permission is hereby granted, free of charge, to any person obtaining a // copy of this software and associated documentation files (the "Software"), // to deal in the Software without restriction, including without limitation // the rights to use, copy, modify, merge, publish, distribute, sublicense, // and/or sell copies of the Software, and to permit persons to whom the // Software is furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS // OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. use bytes::BytesMut; use crate::protocol::{Protocol, MessageReader, Message, Version, ProtocolError}; use futures::{prelude::*, Async, try_ready}; use log::debug; use tokio_io::{AsyncRead, AsyncWrite}; use std::{mem, io, fmt, error::Error}; /// An I/O stream that has settled on an (application-layer) protocol to use. /// /// A `Negotiated` represents an I/O stream that has _settled_ on a protocol /// to use. In particular, it is not implied that all of the protocol negotiation /// frames have yet been sent and / or received, just that the selected protocol /// is fully determined. This is to allow the last protocol negotiation frames /// sent by a peer to be combined in a single write, possibly piggy-backing /// data from the negotiated protocol on top. /// /// Reading from a `Negotiated` I/O stream that still has pending negotiation /// protocol data to send implicitly triggers flushing of all yet unsent data. #[derive(Debug)] pub struct Negotiated { state: State } /// A `Future` that waits on the completion of protocol negotiation. #[derive(Debug)] pub struct NegotiatedComplete { inner: Option> } impl Future for NegotiatedComplete { type Item = Negotiated; type Error = NegotiationError; fn poll(&mut self) -> Poll { let mut io = self.inner.take().expect("NegotiatedFuture called after completion."); if io.poll()?.is_not_ready() { self.inner = Some(io); return Ok(Async::NotReady) } return Ok(Async::Ready(io)) } } impl Negotiated { /// Creates a `Negotiated` in state [`State::Complete`], possibly /// with `remaining` data to be sent. pub(crate) fn completed(io: TInner, remaining: BytesMut) -> Self { Negotiated { state: State::Completed { io, remaining } } } /// Creates a `Negotiated` in state [`State::Expecting`] that is still /// expecting confirmation of the given `protocol`. pub(crate) fn expecting(io: MessageReader, protocol: Protocol) -> Self { Negotiated { state: State::Expecting { io, protocol } } } /// Polls the `Negotiated` for completion. fn poll(&mut self) -> Poll<(), NegotiationError> where TInner: AsyncRead + AsyncWrite { // Flush any pending negotiation data. match self.poll_flush() { Ok(Async::Ready(())) => {}, Ok(Async::NotReady) => return Ok(Async::NotReady), Err(e) => { // If the remote closed the stream, it is important to still // continue reading the data that was sent, if any. if e.kind() != io::ErrorKind::WriteZero { return Err(e.into()) } } } if let State::Completed { remaining, .. } = &mut self.state { let _ = remaining.take(); // Drop remaining data flushed above. return Ok(Async::Ready(())) } // Read outstanding protocol negotiation messages. loop { match mem::replace(&mut self.state, State::Invalid) { State::Expecting { mut io, protocol } => { let msg = match io.poll() { Ok(Async::Ready(Some(msg))) => msg, Ok(Async::NotReady) => { self.state = State::Expecting { io, protocol }; return Ok(Async::NotReady) } Ok(Async::Ready(None)) => { self.state = State::Expecting { io, protocol }; return Err(ProtocolError::IoError( io::ErrorKind::UnexpectedEof.into()).into()) } Err(err) => { self.state = State::Expecting { io, protocol }; return Err(err.into()) } }; if let Message::Header(Version::V1) = &msg { self.state = State::Expecting { io, protocol }; continue } if let Message::Protocol(p) = &msg { if p.as_ref() == protocol.as_ref() { debug!("Negotiated: Received confirmation for protocol: {}", p); let (io, remaining) = io.into_inner(); self.state = State::Completed { io, remaining }; return Ok(Async::Ready(())) } } return Err(NegotiationError::Failed) } _ => panic!("Negotiated: Invalid state") } } } /// Returns a `NegotiatedComplete` future that waits for protocol /// negotiation to complete. pub fn complete(self) -> NegotiatedComplete { NegotiatedComplete { inner: Some(self) } } } /// The states of a `Negotiated` I/O stream. #[derive(Debug)] enum State { /// In this state, a `Negotiated` is still expecting to /// receive confirmation of the protocol it as settled on. Expecting { io: MessageReader, protocol: Protocol }, /// In this state, a protocol has been agreed upon and may /// only be pending the sending of the final acknowledgement, /// which is prepended to / combined with the next write for /// efficiency. Completed { io: R, remaining: BytesMut }, /// Temporary state while moving the `io` resource from /// `Expecting` to `Completed`. Invalid, } impl io::Read for Negotiated where R: AsyncRead + AsyncWrite { fn read(&mut self, buf: &mut [u8]) -> io::Result { loop { if let State::Completed { io, remaining } = &mut self.state { // If protocol negotiation is complete and there is no // remaining data to be flushed, commence with reading. if remaining.is_empty() { return io.read(buf) } } // Poll the `Negotiated`, driving protocol negotiation to completion, // including flushing of any remaining data. let result = self.poll(); // There is still remaining data to be sent before data relating // to the negotiated protocol can be read. if let Ok(Async::NotReady) = result { return Err(io::ErrorKind::WouldBlock.into()) } if let Err(err) = result { return Err(err.into()) } } } } impl AsyncRead for Negotiated where TInner: AsyncRead + AsyncWrite { unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { match &self.state { State::Completed { io, .. } => io.prepare_uninitialized_buffer(buf), State::Expecting { io, .. } => io.inner_ref().prepare_uninitialized_buffer(buf), State::Invalid => panic!("Negotiated: Invalid state") } } } impl io::Write for Negotiated where TInner: AsyncWrite { fn write(&mut self, buf: &[u8]) -> io::Result { match &mut self.state { State::Completed { io, ref mut remaining } => { if !remaining.is_empty() { // Try to write `buf` together with `remaining` for efficiency, // regardless of whether the underlying I/O stream is buffered. // Every call to `write` may imply a syscall and separate // network packet. let remaining_len = remaining.len(); remaining.extend_from_slice(buf); match io.write(&remaining) { Err(e) => { remaining.split_off(remaining_len); Err(e) } Ok(n) => { remaining.split_to(n); if !remaining.is_empty() { let written = if n < buf.len() { remaining.split_off(remaining_len); n } else { buf.len() }; debug_assert!(remaining.len() <= remaining_len); Ok(written) } else { Ok(buf.len()) } } } } else { io.write(buf) } }, State::Expecting { io, .. } => io.write(buf), State::Invalid => panic!("Negotiated: Invalid state") } } fn flush(&mut self) -> io::Result<()> { match &mut self.state { State::Completed { io, ref mut remaining } => { while !remaining.is_empty() { let n = io.write(remaining)?; if n == 0 { return Err(io::Error::new( io::ErrorKind::WriteZero, "Failed to write remaining buffer.")) } remaining.split_to(n); } io.flush() }, State::Expecting { io, .. } => io.flush(), State::Invalid => panic!("Negotiated: Invalid state") } } } impl AsyncWrite for Negotiated where TInner: AsyncWrite + AsyncRead { fn shutdown(&mut self) -> Poll<(), io::Error> { // Ensure all data has been flushed and expected negotiation messages // have been received. try_ready!(self.poll().map_err(Into::::into)); // Continue with the shutdown of the underlying I/O stream. match &mut self.state { State::Completed { io, .. } => io.shutdown(), State::Expecting { io, .. } => io.shutdown(), State::Invalid => panic!("Negotiated: Invalid state") } } } /// Error that can happen when negotiating a protocol with the remote. #[derive(Debug)] pub enum NegotiationError { /// A protocol error occurred during the negotiation. ProtocolError(ProtocolError), /// Protocol negotiation failed because no protocol could be agreed upon. Failed, } impl From for NegotiationError { fn from(err: ProtocolError) -> NegotiationError { NegotiationError::ProtocolError(err) } } impl From for NegotiationError { fn from(err: io::Error) -> NegotiationError { ProtocolError::from(err).into() } } impl Into for NegotiationError { fn into(self) -> io::Error { if let NegotiationError::ProtocolError(e) = self { return e.into() } io::Error::new(io::ErrorKind::Other, self) } } impl Error for NegotiationError { fn source(&self) -> Option<&(dyn Error + 'static)> { match self { NegotiationError::ProtocolError(err) => Some(err), _ => None, } } } impl fmt::Display for NegotiationError { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { write!(fmt, "{}", Error::description(self)) } } #[cfg(test)] mod tests { use super::*; use quickcheck::*; use std::io::Write; /// An I/O resource with a fixed write capacity (total and per write op). struct Capped { buf: Vec, step: usize } impl io::Write for Capped { fn write(&mut self, buf: &[u8]) -> io::Result { if self.buf.len() + buf.len() > self.buf.capacity() { return Err(io::ErrorKind::WriteZero.into()) } self.buf.write(&buf[.. usize::min(self.step, buf.len())]) } fn flush(&mut self) -> io::Result<()> { Ok(()) } } impl AsyncWrite for Capped { fn shutdown(&mut self) -> Poll<(), io::Error> { Ok(().into()) } } #[test] fn write_remaining() { fn prop(rem: Vec, new: Vec, free: u8) -> TestResult { let cap = rem.len() + free as usize; let buf = Capped { buf: Vec::with_capacity(cap), step: free as usize }; let mut rem = BytesMut::from(rem); let mut io = Negotiated::completed(buf, rem.clone()); let mut written = 0; loop { // Write until `new` has been fully written or the capped buffer is // full (in which case the buffer should remain unchanged from the // last successful write). match io.write(&new[written..]) { Ok(n) => if let State::Completed { remaining, .. } = &io.state { if n == rem.len() + new[written..].len() { assert!(remaining.is_empty()) } else { assert!(remaining.len() <= rem.len()); } written += n; if written == new.len() { return TestResult::passed() } rem = remaining.clone(); } else { return TestResult::failed() } Err(_) => if let State::Completed { remaining, .. } = &io.state { assert!(rem.len() + new[written..].len() > cap); assert_eq!(remaining, &rem); return TestResult::passed() } else { return TestResult::failed() } } } } quickcheck(prop as fn(_,_,_) -> _) } }