// Source-viewer header residue: 555 lines · 18 KiB · Rust · Raw / Normal View / History

use async_std::io;
use async_trait::async_trait;
use either::Either;
use futures::channel::{mpsc, oneshot};
use futures::prelude::*;
use libp2p::{
core::{
upgrade::{read_length_prefixed, write_length_prefixed, ProtocolName},
Multiaddr,
},
identity,
kad::{
record::store::MemoryStore, GetProvidersOk, Kademlia, KademliaEvent, QueryId, QueryResult,
},
multiaddr::Protocol,
noise,
request_response::{self, ProtocolSupport, RequestId, ResponseChannel},
swarm::{ConnectionHandlerUpgrErr, NetworkBehaviour, Swarm, SwarmBuilder, SwarmEvent},
tcp, yamux, PeerId, Transport,
};
use libp2p::core::upgrade::Version;
use std::collections::{hash_map, HashMap, HashSet};
use std::error::Error;
use std::iter;
/// Builds the three parts of the networking layer and hands them back together:
///
/// - a [`Client`] that can be cloned and used from anywhere in the application
///   to issue commands to the network,
///
/// - a stream of [`Event`]s produced by the network (e.g. inbound requests),
///
/// - the [`EventLoop`] that has to be run for anything to actually happen.
///
/// When `secret_key_seed` is `Some`, the identity key is derived from a
/// 32-byte buffer whose first byte is the seed (the rest is zero), which makes
/// the resulting peer id reproducible. With `None` a random key is generated.
///
/// # Errors
///
/// Returns an error if the noise configuration cannot be created from the
/// identity key pair.
pub(crate) async fn new(
    secret_key_seed: Option<u8>,
) -> Result<(Client, impl Stream<Item = Event>, EventLoop), Box<dyn Error>> {
    // Seeded key pair when a seed is supplied, otherwise a random one.
    let id_keys = secret_key_seed.map_or_else(identity::Keypair::generate_ed25519, |seed| {
        let mut bytes = [0u8; 32];
        bytes[0] = seed;
        identity::Keypair::ed25519_from_bytes(bytes).unwrap()
    });
    let peer_id = id_keys.public().to_peer_id();

    // TCP, upgraded with noise authentication and yamux multiplexing.
    let transport = tcp::async_io::Transport::default()
        .upgrade(Version::V1Lazy)
        .authenticate(noise::Config::new(&id_keys)?)
        .multiplex(yamux::YamuxConfig::default())
        .boxed();

    // The behaviour combines the Kademlia DHT with the file-exchange
    // request/response protocol.
    let behaviour = ComposedBehaviour {
        kademlia: Kademlia::new(peer_id, MemoryStore::new(peer_id)),
        request_response: request_response::Behaviour::new(
            FileExchangeCodec(),
            iter::once((FileExchangeProtocol(), ProtocolSupport::Full)),
            Default::default(),
        ),
    };

    // The swarm ties the transport to the behaviour.
    let swarm = SwarmBuilder::with_async_std_executor(transport, behaviour, peer_id).build();

    // Both channels are created with a buffer size of 0.
    let (command_sender, command_receiver) = mpsc::channel(0);
    let (event_sender, event_receiver) = mpsc::channel(0);

    let client = Client {
        sender: command_sender,
    };
    let event_loop = EventLoop::new(swarm, command_receiver, event_sender);

    Ok((client, event_receiver, event_loop))
}
/// Cloneable handle for submitting [`Command`]s to the network event loop.
#[derive(Clone)]
pub(crate) struct Client {
// Sending half of the command channel; the receiving half is handed to
// `EventLoop::new` when the network is created.
sender: mpsc::Sender<Command>,
}
/// Command-submitting API of the network client.
///
/// Every method forwards a [`Command`] over the command channel and, where the
/// command carries a reply sender, waits for the event loop's answer.
///
/// # Panics
///
/// All methods panic if the command receiver has been dropped. Methods that
/// wait for a reply also panic if the reply sender is dropped without an
/// answer being sent.
impl Client {
    /// Listen for incoming connections on the given address.
    ///
    /// Returns the result the event loop reports for the listen attempt.
    pub(crate) async fn start_listening(
        &mut self,
        addr: Multiaddr,
    ) -> Result<(), Box<dyn Error + Send>> {
        let (sender, receiver) = oneshot::channel();
        self.sender
            .send(Command::StartListening { addr, sender })
            .await
            .expect("Command receiver not to be dropped.");
        receiver.await.expect("Sender not to be dropped.")
    }

    /// Dial the given peer at the given address.
    ///
    /// Returns the result the event loop reports for the dial attempt.
    pub(crate) async fn dial(
        &mut self,
        peer_id: PeerId,
        peer_addr: Multiaddr,
    ) -> Result<(), Box<dyn Error + Send>> {
        let (sender, receiver) = oneshot::channel();
        self.sender
            .send(Command::Dial {
                peer_id,
                peer_addr,
                sender,
            })
            .await
            .expect("Command receiver not to be dropped.");
        receiver.await.expect("Sender not to be dropped.")
    }

    /// Advertise the local node as the provider of the given file on the DHT.
    ///
    /// Completes once the event loop signals that the query has finished.
    pub(crate) async fn start_providing(&mut self, file_name: String) {
        let (sender, receiver) = oneshot::channel();
        self.sender
            .send(Command::StartProviding { file_name, sender })
            .await
            .expect("Command receiver not to be dropped.");
        receiver.await.expect("Sender not to be dropped.");
    }

    /// Find the providers for the given file on the DHT.
    ///
    /// Returns the set of provider peer ids the event loop sends back.
    pub(crate) async fn get_providers(&mut self, file_name: String) -> HashSet<PeerId> {
        let (sender, receiver) = oneshot::channel();
        self.sender
            .send(Command::GetProviders { file_name, sender })
            .await
            .expect("Command receiver not to be dropped.");
        receiver.await.expect("Sender not to be dropped.")
    }

    /// Request the content of the given file from the given peer.
    ///
    /// Returns the file bytes on success, or the error reported for the
    /// outbound request.
    pub(crate) async fn request_file(
        &mut self,
        peer: PeerId,
        file_name: String,
    ) -> Result<Vec<u8>, Box<dyn Error + Send>> {
        let (sender, receiver) = oneshot::channel();
        self.sender
            .send(Command::RequestFile {
                file_name,
                peer,
                sender,
            })
            .await
            .expect("Command receiver not to be dropped.");
        // Message aligned with the other methods (was "Sender not be dropped.").
        receiver.await.expect("Sender not to be dropped.")
    }

    /// Respond with the provided file content to the given request.
    ///
    /// This only hands the response to the event loop; no reply is awaited.
    pub(crate) async fn respond_file(
        &mut self,
        file: Vec<u8>,
        channel: ResponseChannel<FileResponse>,
    ) {
        self.sender
            .send(Command::RespondFile { file, channel })
            .await
            .expect("Command receiver not to be dropped.");
    }
}
/// Owns the swarm and tracks operations that are still waiting for a network
/// event, so their result can be sent back to the requesting `Client` call.
pub(crate) struct EventLoop {
// The libp2p swarm driven by `run`.
swarm: Swarm<ComposedBehaviour>,
// Commands arriving from `Client` handles.
command_receiver: mpsc::Receiver<Command>,
// Outlet for events surfaced to the application (inbound requests).
event_sender: mpsc::Sender<Event>,
// Dials in flight, keyed by the dialed peer.
pending_dial: HashMap<PeerId, oneshot::Sender<Result<(), Box<dyn Error + Send>>>>,
// `start_providing` queries in flight, keyed by Kademlia query id.
pending_start_providing: HashMap<QueryId, oneshot::Sender<()>>,
// `get_providers` queries in flight, keyed by Kademlia query id.
pending_get_providers: HashMap<QueryId, oneshot::Sender<HashSet<PeerId>>>,
// Outbound file requests in flight, keyed by request id.
pending_request_file:
HashMap<RequestId, oneshot::Sender<Result<Vec<u8>, Box<dyn Error + Send>>>>,
}
impl EventLoop {
fn new(
swarm: Swarm<ComposedBehaviour>,
command_receiver: mpsc::Receiver<Command>,
event_sender: mpsc::Sender<Event>,
) -> Self {
Self {
swarm,
command_receiver,
event_sender,
pending_dial: Default::default(),
pending_start_providing: Default::default(),
pending_get_providers: Default::default(),
pending_request_file: Default::default(),
}
}
pub(crate) async fn run(mut self) {
loop {
futures::select! {
event = self.swarm.next() => self.handle_event(event.expect("Swarm stream to be infinite.")).await ,
command = self.command_receiver.next() => match command {
Some(c) => self.handle_command(c).await,
// Command channel closed, thus shutting down the network event loop.
None=> return,
},
}
}
}
async fn handle_event(
&mut self,
event: SwarmEvent<ComposedEvent, Either<ConnectionHandlerUpgrErr<io::Error>, io::Error>>,
) {
match event {
SwarmEvent::Behaviour(ComposedEvent::Kademlia(
KademliaEvent::OutboundQueryProgressed {
id,
result: QueryResult::StartProviding(_),
..
},
)) => {
let sender: oneshot::Sender<()> = self
.pending_start_providing
.remove(&id)
.expect("Completed query to be previously pending.");
let _ = sender.send(());
}
SwarmEvent::Behaviour(ComposedEvent::Kademlia(
KademliaEvent::OutboundQueryProgressed {
id,
result:
QueryResult::GetProviders(Ok(GetProvidersOk::FoundProviders {
providers, ..
})),
..
},
)) => {
if let Some(sender) = self.pending_get_providers.remove(&id) {
sender.send(providers).expect("Receiver not to be dropped");
// Finish the query. We are only interested in the first result.
self.swarm
.behaviour_mut()
.kademlia
.query_mut(&id)
.unwrap()
.finish();
}
}
SwarmEvent::Behaviour(ComposedEvent::Kademlia(
KademliaEvent::OutboundQueryProgressed {
result:
QueryResult::GetProviders(Ok(GetProvidersOk::FinishedWithNoAdditionalRecord {
..
})),
..
},
)) => {}
SwarmEvent::Behaviour(ComposedEvent::Kademlia(_)) => {}
SwarmEvent::Behaviour(ComposedEvent::RequestResponse(
request_response::Event::Message { message, .. },
)) => match message {
request_response::Message::Request {
request, channel, ..
} => {
self.event_sender
.send(Event::InboundRequest {
request: request.0,
channel,
})
.await
.expect("Event receiver not to be dropped.");
}
request_response::Message::Response {
request_id,
response,
} => {
let _ = self
.pending_request_file
.remove(&request_id)
.expect("Request to still be pending.")
.send(Ok(response.0));
}
},
SwarmEvent::Behaviour(ComposedEvent::RequestResponse(
request_response::Event::OutboundFailure {
request_id, error, ..
},
)) => {
let _ = self
.pending_request_file
.remove(&request_id)
.expect("Request to still be pending.")
.send(Err(Box::new(error)));
}
SwarmEvent::Behaviour(ComposedEvent::RequestResponse(
request_response::Event::ResponseSent { .. },
)) => {}
SwarmEvent::NewListenAddr { address, .. } => {
let local_peer_id = *self.swarm.local_peer_id();
eprintln!(
"Local node is listening on {:?}",
address.with(Protocol::P2p(local_peer_id.into()))
);
}
SwarmEvent::IncomingConnection { .. } => {}
SwarmEvent::ConnectionEstablished {
peer_id, endpoint, ..
} => {
if endpoint.is_dialer() {
if let Some(sender) = self.pending_dial.remove(&peer_id) {
let _ = sender.send(Ok(()));
}
}
}
SwarmEvent::ConnectionClosed { .. } => {}
SwarmEvent::OutgoingConnectionError { peer_id, error, .. } => {
if let Some(peer_id) = peer_id {
if let Some(sender) = self.pending_dial.remove(&peer_id) {
let _ = sender.send(Err(Box::new(error)));
}
}
}
SwarmEvent::IncomingConnectionError { .. } => {}
SwarmEvent::Dialing(peer_id) => eprintln!("Dialing {peer_id}"),
e => panic!("{e:?}"),
}
}
async fn handle_command(&mut self, command: Command) {
match command {
Command::StartListening { addr, sender } => {
let _ = match self.swarm.listen_on(addr) {
Ok(_) => sender.send(Ok(())),
Err(e) => sender.send(Err(Box::new(e))),
};
}
Command::Dial {
peer_id,
peer_addr,
sender,
} => {
if let hash_map::Entry::Vacant(e) = self.pending_dial.entry(peer_id) {
self.swarm
.behaviour_mut()
.kademlia
.add_address(&peer_id, peer_addr.clone());
match self
.swarm
.dial(peer_addr.with(Protocol::P2p(peer_id.into())))
{
Ok(()) => {
e.insert(sender);
}
Err(e) => {
let _ = sender.send(Err(Box::new(e)));
}
}
} else {
todo!("Already dialing peer.");
}
}
Command::StartProviding { file_name, sender } => {
let query_id = self
.swarm
.behaviour_mut()
.kademlia
.start_providing(file_name.into_bytes().into())
.expect("No store error.");
self.pending_start_providing.insert(query_id, sender);
}
Command::GetProviders { file_name, sender } => {
let query_id = self
.swarm
.behaviour_mut()
.kademlia
.get_providers(file_name.into_bytes().into());
self.pending_get_providers.insert(query_id, sender);
}
Command::RequestFile {
file_name,
peer,
sender,
} => {
let request_id = self
.swarm
.behaviour_mut()
.request_response
.send_request(&peer, FileRequest(file_name));
self.pending_request_file.insert(request_id, sender);
}
Command::RespondFile { file, channel } => {
self.swarm
.behaviour_mut()
.request_response
.send_response(channel, FileResponse(file))
.expect("Connection to peer to be still open.");
}
}
}
}
/// Network behaviour combining the file-exchange request/response protocol
/// with the Kademlia DHT. Its events are emitted as [`ComposedEvent`].
// NOTE(review): the field order presumably determines the nesting of the
// handler error type used in `EventLoop::handle_event` — confirm before
// reordering.
#[derive(NetworkBehaviour)]
#[behaviour(out_event = "ComposedEvent")]
struct ComposedBehaviour {
// Request/response exchange of files, encoded by `FileExchangeCodec`.
request_response: request_response::Behaviour<FileExchangeCodec>,
// Kademlia DHT backed by an in-memory record store.
kademlia: Kademlia<MemoryStore>,
}
/// Event type emitted by [`ComposedBehaviour`]: one variant per
/// sub-behaviour, each wrapping that behaviour's own event.
#[derive(Debug)]
enum ComposedEvent {
/// An event from the file-exchange request/response behaviour.
RequestResponse(request_response::Event<FileRequest, FileResponse>),
/// An event from the Kademlia behaviour.
Kademlia(KademliaEvent),
}
/// Lets request/response events be converted into [`ComposedEvent`].
impl From<request_response::Event<FileRequest, FileResponse>> for ComposedEvent {
    /// Wraps the event in the `RequestResponse` variant.
    fn from(inner: request_response::Event<FileRequest, FileResponse>) -> Self {
        Self::RequestResponse(inner)
    }
}
impl From<KademliaEvent> for ComposedEvent {
fn from(event: KademliaEvent) -> Self {
ComposedEvent::Kademlia(event)
}
}
/// Instructions sent from a `Client` to the `EventLoop`. Variants that carry
/// a `sender` use it to report the outcome back to the caller.
#[derive(Debug)]
enum Command {
/// Start listening on `addr`; the result of the listen attempt is sent back.
StartListening {
addr: Multiaddr,
sender: oneshot::Sender<Result<(), Box<dyn Error + Send>>>,
},
/// Dial `peer_id` at `peer_addr`; the result of the dial is sent back.
Dial {
peer_id: PeerId,
peer_addr: Multiaddr,
sender: oneshot::Sender<Result<(), Box<dyn Error + Send>>>,
},
/// Announce the local node as a provider of `file_name` on the DHT;
/// `sender` is signalled when the query completes.
StartProviding {
file_name: String,
sender: oneshot::Sender<()>,
},
/// Look up providers of `file_name` on the DHT; the found peers are sent back.
GetProviders {
file_name: String,
sender: oneshot::Sender<HashSet<PeerId>>,
},
/// Request `file_name` from `peer`; the file bytes or the request error
/// are sent back.
RequestFile {
file_name: String,
peer: PeerId,
sender: oneshot::Sender<Result<Vec<u8>, Box<dyn Error + Send>>>,
},
/// Answer an inbound request on `channel` with the bytes in `file`.
RespondFile {
file: Vec<u8>,
channel: ResponseChannel<FileResponse>,
},
}
/// Events the network layer surfaces to the application.
#[derive(Debug)]
pub(crate) enum Event {
/// A remote peer sent a file request.
InboundRequest {
// The string carried by the incoming `FileRequest`.
request: String,
// Channel on which the response to this request must be sent.
channel: ResponseChannel<FileResponse>,
},
}
// Simple file exchange protocol.
/// Marker type identifying the file-exchange protocol; its wire name is
/// supplied by the `ProtocolName` implementation.
#[derive(Debug, Clone)]
struct FileExchangeProtocol();
/// Stateless codec that reads and writes `FileRequest`/`FileResponse`
/// messages as length-prefixed byte strings.
#[derive(Clone)]
struct FileExchangeCodec();
/// A file request; the wrapped string is the name of the requested file.
#[derive(Debug, Clone, PartialEq, Eq)]
struct FileRequest(String);
/// A file response; the wrapped bytes are the file content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct FileResponse(Vec<u8>);
/// Supplies the wire-level name of the file-exchange protocol.
impl ProtocolName for FileExchangeProtocol {
    /// Returns the protocol identifier as raw bytes.
    fn protocol_name(&self) -> &[u8] {
        b"/file-exchange/1"
    }
}
#[async_trait]
impl request_response::Codec for FileExchangeCodec {
type Protocol = FileExchangeProtocol;
type Request = FileRequest;
type Response = FileResponse;
async fn read_request<T>(
&mut self,
_: &FileExchangeProtocol,
io: &mut T,
) -> io::Result<Self::Request>
where
T: AsyncRead + Unpin + Send,
{
let vec = read_length_prefixed(io, 1_000_000).await?;
if vec.is_empty() {
return Err(io::ErrorKind::UnexpectedEof.into());
}
Ok(FileRequest(String::from_utf8(vec).unwrap()))
}
async fn read_response<T>(
&mut self,
_: &FileExchangeProtocol,
io: &mut T,
) -> io::Result<Self::Response>
where
T: AsyncRead + Unpin + Send,
{
let vec = read_length_prefixed(io, 500_000_000).await?; // update transfer maximum
if vec.is_empty() {
return Err(io::ErrorKind::UnexpectedEof.into());
}
Ok(FileResponse(vec))
}
async fn write_request<T>(
&mut self,
_: &FileExchangeProtocol,
io: &mut T,
FileRequest(data): FileRequest,
) -> io::Result<()>
where
T: AsyncWrite + Unpin + Send,
{
write_length_prefixed(io, data).await?;
io.close().await?;
Ok(())
}
async fn write_response<T>(
&mut self,
_: &FileExchangeProtocol,
io: &mut T,
FileResponse(data): FileResponse,
) -> io::Result<()>
where
T: AsyncWrite + Unpin + Send,
{
write_length_prefixed(io, data).await?;
io.close().await?;
Ok(())
}
}