diff --git a/core/src/protocols_handler/mod.rs b/core/src/protocols_handler/mod.rs index 37c8de94..ff4e922e 100644 --- a/core/src/protocols_handler/mod.rs +++ b/core/src/protocols_handler/mod.rs @@ -40,7 +40,7 @@ use crate::upgrade::{ UpgradeError, }; use futures::prelude::*; -use std::{error, fmt, time::Duration, time::Instant}; +use std::{cmp::Ordering, error, fmt, time::Duration, time::Instant}; use tokio_io::{AsyncRead, AsyncWrite}; pub use self::dummy::DummyProtocolsHandler; @@ -152,8 +152,8 @@ pub trait ProtocolsHandler { /// On the other hand, the return value is only an indication and doesn't mean that the user /// will not call `shutdown()`. /// - /// When multiple `ProtocolsHandler` are combined together, they should use return the largest - /// value of the two, or `Forever` if either returns `Forever`. + /// When multiple `ProtocolsHandler` are combined together, the largest `KeepAlive` should be + /// used. /// /// The result of this method should be checked every time `poll()` is invoked. /// @@ -434,3 +434,22 @@ impl KeepAlive { } } } + +impl PartialOrd for KeepAlive { + fn partial_cmp(&self, other: &KeepAlive) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for KeepAlive { + fn cmp(&self, other: &KeepAlive) -> Ordering { + use self::KeepAlive::*; + + match (self, other) { + (Now, Now) | (Forever, Forever) => Ordering::Equal, + (Now, _) | (_, Forever) => Ordering::Less, + (_, Now) | (Forever, _) => Ordering::Greater, + (Until(expiration), Until(other_expiration)) => expiration.cmp(other_expiration), + } + } +} diff --git a/core/src/protocols_handler/select.rs b/core/src/protocols_handler/select.rs index 392e143f..cf0c7cce 100644 --- a/core/src/protocols_handler/select.rs +++ b/core/src/protocols_handler/select.rs @@ -210,11 +210,7 @@ where #[inline] fn connection_keep_alive(&self) -> KeepAlive { - match (self.proto1.connection_keep_alive(), self.proto2.connection_keep_alive()) { - (KeepAlive::Forever, _) | (_, KeepAlive::Forever) => KeepAlive::Forever, - (a, KeepAlive::Now) | (KeepAlive::Now, a) => a, - (KeepAlive::Until(a), KeepAlive::Until(b)) => KeepAlive::Until(cmp::max(a, b)), - } + cmp::max(self.proto1.connection_keep_alive(), self.proto2.connection_keep_alive()) } #[inline]