rust-libp2p/swarm/src/stream_protocol.rs
Thomas Eizinger c93f753018
feat: replace ProtocolName with AsRef<str>
Previously, a protocol could be any sequence of bytes as long as it started with `/`. We now parse a protocol directly as a `String`, which requires it to be valid UTF-8.

To notify users of this change, we delete the `ProtocolName` trait. The new requirement is that users need to provide a type that implements `AsRef<str>`.
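
As a rough illustration of what the `AsRef<str>` bound permits, the sketch below accepts both string literals and owned strings through a single generic function. The helper `is_well_formed` is a hypothetical name invented for this example and is not part of libp2p.

```rust
/// Hypothetical helper (not a libp2p API): accepts any protocol type that
/// implements `AsRef<str>` and checks for the leading forward slash.
fn is_well_formed(protocol: impl AsRef<str>) -> bool {
    protocol.as_ref().starts_with('/')
}

fn main() {
    // A `&'static str` and an owned `String` both satisfy the bound.
    assert!(is_well_formed("/ipfs/ping/1.0.0"));
    assert!(is_well_formed(String::from("/my-app/chat/2")));
    assert!(!is_well_formed("chat"));
}
```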

We also add a `StreamProtocol` newtype in `libp2p-swarm` which provides an easy way for users to ensure their protocol strings are compliant. The newtype enforces that protocol strings start with `/`. `StreamProtocol` also implements `AsRef<str>`, meaning users can directly use it in their upgrades.

`multistream-select` by itself only changes marginally with this patch. The only thing we enforce in the type-system is that protocols must implement `AsRef<str>`.

Resolves: #2831.

Pull-Request: #3746.
2023-05-04 04:47:11 +00:00

use either::Either;
use std::fmt;
use std::hash::{Hash, Hasher};
use std::sync::Arc;

/// Identifies a protocol for a stream.
///
/// libp2p nodes use stream protocols to negotiate what to do with a newly opened stream.
/// Stream protocols are string-based and must start with a forward slash: `/`.
#[derive(Debug, Clone, Eq)]
pub struct StreamProtocol {
    inner: Either<&'static str, Arc<str>>,
}

impl StreamProtocol {
    /// Construct a new protocol from a static string slice.
    ///
    /// # Panics
    ///
    /// This function panics if the protocol does not start with a forward slash: `/`.
    pub const fn new(s: &'static str) -> Self {
        match s.as_bytes() {
            [b'/', ..] => {}
            _ => panic!("Protocols should start with a /"),
        }

        StreamProtocol {
            inner: Either::Left(s),
        }
    }

    /// Attempt to construct a protocol from an owned string.
    ///
    /// This function will fail if the protocol does not start with a forward slash: `/`.
    /// Where possible, you should use [`StreamProtocol::new`] instead to avoid allocations.
    pub fn try_from_owned(protocol: String) -> Result<Self, InvalidProtocol> {
        if !protocol.starts_with('/') {
            return Err(InvalidProtocol::missing_forward_slash());
        }

        Ok(StreamProtocol {
            // FIXME: Can we somehow reuse the allocation from the owned string?
            inner: Either::Right(Arc::from(protocol)),
        })
    }
}

impl AsRef<str> for StreamProtocol {
    fn as_ref(&self) -> &str {
        either::for_both!(&self.inner, s => s)
    }
}

impl fmt::Display for StreamProtocol {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.inner.fmt(f)
    }
}

impl PartialEq<&str> for StreamProtocol {
    fn eq(&self, other: &&str) -> bool {
        self.as_ref() == *other
    }
}

impl PartialEq<StreamProtocol> for &str {
    fn eq(&self, other: &StreamProtocol) -> bool {
        *self == other.as_ref()
    }
}

impl PartialEq for StreamProtocol {
    fn eq(&self, other: &Self) -> bool {
        self.as_ref() == other.as_ref()
    }
}

impl Hash for StreamProtocol {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.as_ref().hash(state)
    }
}

#[derive(Debug)]
pub struct InvalidProtocol {
    // private field to prevent construction outside of this module
    _private: (),
}

impl InvalidProtocol {
    pub(crate) fn missing_forward_slash() -> Self {
        InvalidProtocol { _private: () }
    }
}

impl fmt::Display for InvalidProtocol {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "invalid protocol: string does not start with a forward slash"
        )
    }
}

impl std::error::Error for InvalidProtocol {}
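
The file above implements `PartialEq` and `Hash` by hand on top of `as_ref()`, so a protocol built from a static string and one built from an owned string should compare and hash identically. The sketch below illustrates that behaviour, plus the `const fn` constructor, using only the standard library. `Proto` is a hypothetical stand-in written for this example: it uses a plain enum where the real type uses `Either<&'static str, Arc<str>>` from the `either` crate, and a `String` error where the real type returns `InvalidProtocol`.

```rust
use std::collections::HashSet;
use std::hash::{Hash, Hasher};
use std::sync::Arc;

/// Hypothetical std-only stand-in for `StreamProtocol` (an assumption of
/// this sketch; the real type wraps `Either<&'static str, Arc<str>>`).
#[derive(Debug, Clone)]
enum Proto {
    Static(&'static str),
    Owned(Arc<str>),
}

impl Proto {
    /// Mirrors `StreamProtocol::new`: panics without a leading `/`.
    const fn new(s: &'static str) -> Self {
        match s.as_bytes() {
            [b'/', ..] => {}
            _ => panic!("Protocols should start with a /"),
        }
        Proto::Static(s)
    }

    /// Mirrors `StreamProtocol::try_from_owned`, with a simplified error type.
    fn try_from_owned(s: String) -> Result<Self, String> {
        if !s.starts_with('/') {
            return Err(format!("{:?} does not start with a forward slash", s));
        }
        Ok(Proto::Owned(Arc::from(s)))
    }
}

impl AsRef<str> for Proto {
    fn as_ref(&self) -> &str {
        match self {
            Proto::Static(s) => *s,
            Proto::Owned(s) => &**s,
        }
    }
}

// Equality and hashing go through the string contents, not the variant.
impl PartialEq for Proto {
    fn eq(&self, other: &Self) -> bool {
        self.as_ref() == other.as_ref()
    }
}

impl Eq for Proto {}

impl Hash for Proto {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.as_ref().hash(state)
    }
}

// Because `new` is a `const fn`, a missing `/` here would be rejected
// during constant evaluation rather than at runtime.
const PING: Proto = Proto::new("/ping/1.0.0");

fn main() {
    let owned = Proto::try_from_owned(String::from("/ping/1.0.0")).unwrap();

    // Static and owned variants with the same contents are equal...
    assert_eq!(PING, owned);

    // ...and hash the same, so either can look up the other in a set.
    let mut supported = HashSet::new();
    supported.insert(PING);
    assert!(supported.contains(&owned));

    // Strings without the leading slash are rejected.
    assert!(Proto::try_from_owned(String::from("ping/1.0.0")).is_err());
}
```

Deriving `PartialEq` and `Hash` on the enum instead would treat the static and owned variants as different values even when their contents match, which is presumably why the original implements both traits manually.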