diff --git a/protocols/request-response/src/handler/protocol.rs b/protocols/request-response/src/handler/protocol.rs index bff57058..fb44c8e9 100644 --- a/protocols/request-response/src/handler/protocol.rs +++ b/protocols/request-response/src/handler/protocol.rs @@ -106,6 +106,7 @@ where let write = self.codec.write_response(&protocol, &mut io, response); write.await?; } else { + io.close().await?; return Ok(false) } } diff --git a/protocols/request-response/src/lib.rs b/protocols/request-response/src/lib.rs index d92409eb..dae0c11c 100644 --- a/protocols/request-response/src/lib.rs +++ b/protocols/request-response/src/lib.rs @@ -163,7 +163,7 @@ pub enum RequestResponseEvent /// Possible failures occurring in the context of sending /// an outbound request and receiving the response. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub enum OutboundFailure { /// The request could not be sent because a dialing attempt failed. DialFailure, @@ -183,7 +183,7 @@ pub enum OutboundFailure { /// Possible failures occurring in the context of receiving an /// inbound request and sending a response. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub enum InboundFailure { /// The inbound request timed out, either while reading the /// incoming request or before a response is sent, e.g. if diff --git a/protocols/request-response/tests/ping.rs b/protocols/request-response/tests/ping.rs index 9433f67f..7c07d6a1 100644 --- a/protocols/request-response/tests/ping.rs +++ b/protocols/request-response/tests/ping.rs @@ -47,8 +47,8 @@ fn is_response_outbound() { let cfg = RequestResponseConfig::default(); let (peer1_id, trans) = mk_transport(); - let ping_proto1 = RequestResponse::new(PingCodec(), protocols.clone(), cfg.clone()); - let mut swarm1 = Swarm::new(trans, ping_proto1, peer1_id.clone()); + let ping_proto1 = RequestResponse::new(PingCodec(), protocols, cfg); + let mut swarm1 = Swarm::new(trans, ping_proto1, peer1_id); let request_id1 = swarm1.send_request(&offline_peer, ping.clone()); @@ -60,7 +60,7 @@ fn is_response_outbound() { e => panic!("Peer: Unexpected event: {:?}", e), } - let request_id2 = swarm1.send_request(&offline_peer, ping.clone()); + let request_id2 = swarm1.send_request(&offline_peer, ping); assert!(!swarm1.is_pending_outbound(&offline_peer, &request_id1)); assert!(swarm1.is_pending_outbound(&offline_peer, &request_id2)); @@ -77,11 +77,11 @@ fn ping_protocol() { let (peer1_id, trans) = mk_transport(); let ping_proto1 = RequestResponse::new(PingCodec(), protocols.clone(), cfg.clone()); - let mut swarm1 = Swarm::new(trans, ping_proto1, peer1_id.clone()); + let mut swarm1 = Swarm::new(trans, ping_proto1, peer1_id); let (peer2_id, trans) = mk_transport(); let ping_proto2 = RequestResponse::new(PingCodec(), protocols, cfg); - let mut swarm2 = Swarm::new(trans, ping_proto2, peer2_id.clone()); + let mut swarm2 = Swarm::new(trans, ping_proto2, peer2_id); let (mut tx, mut rx) = mpsc::channel::(1); @@ -157,17 +157,17 @@ fn emits_inbound_connection_closed_failure() { let (peer1_id, trans) = mk_transport(); let ping_proto1 = RequestResponse::new(PingCodec(), protocols.clone(), cfg.clone()); - let mut swarm1 = Swarm::new(trans, ping_proto1, peer1_id.clone()); + let mut swarm1 = Swarm::new(trans, ping_proto1, peer1_id); let (peer2_id, trans) = mk_transport(); let ping_proto2 = RequestResponse::new(PingCodec(), protocols, cfg); - let mut swarm2 = Swarm::new(trans, ping_proto2, peer2_id.clone()); + let mut swarm2 = Swarm::new(trans, ping_proto2, peer2_id); let addr = "/ip4/127.0.0.1/tcp/0".parse().unwrap(); Swarm::listen_on(&mut swarm1, addr).unwrap(); futures::executor::block_on(async move { - while let Some(_) = swarm1.next().now_or_never() {} + while swarm1.next().now_or_never().is_some() {} let addr1 = Swarm::listeners(&swarm1).next().unwrap(); swarm2.add_address(&peer1_id, addr1.clone()); @@ -201,6 +201,64 @@ fn emits_inbound_connection_closed_failure() { }); } +/// We expect the substream to be properly closed when response channel is dropped. +/// Since the ping protocol used here expects a response, the sender considers this +/// early close as a protocol violation which results in the connection being closed. +/// If the substream were not properly closed when dropped, the sender would instead +/// run into a timeout waiting for the response. +#[test] +fn emits_inbound_connection_closed_if_channel_is_dropped() { + let ping = Ping("ping".to_string().into_bytes()); + + let protocols = iter::once((PingProtocol(), ProtocolSupport::Full)); + let cfg = RequestResponseConfig::default(); + + let (peer1_id, trans) = mk_transport(); + let ping_proto1 = RequestResponse::new(PingCodec(), protocols.clone(), cfg.clone()); + let mut swarm1 = Swarm::new(trans, ping_proto1, peer1_id); + + let (peer2_id, trans) = mk_transport(); + let ping_proto2 = RequestResponse::new(PingCodec(), protocols, cfg); + let mut swarm2 = Swarm::new(trans, ping_proto2, peer2_id); + + let addr = "/ip4/127.0.0.1/tcp/0".parse().unwrap(); + Swarm::listen_on(&mut swarm1, addr).unwrap(); + + futures::executor::block_on(async move { + while swarm1.next().now_or_never().is_some() {} + let addr1 = Swarm::listeners(&swarm1).next().unwrap(); + + swarm2.add_address(&peer1_id, addr1.clone()); + swarm2.send_request(&peer1_id, ping.clone()); + + // Wait for swarm 1 to receive request by swarm 2. + let event = loop { + futures::select!( + event = swarm1.next().fuse() => if let RequestResponseEvent::Message { + peer, + message: RequestResponseMessage::Request { request, channel, .. } + } = event { + assert_eq!(&request, &ping); + assert_eq!(&peer, &peer2_id); + + drop(channel); + continue; + }, + event = swarm2.next().fuse() => { + break event; + }, + ) + }; + + let error = match event { + RequestResponseEvent::OutboundFailure { error, .. } => error, + e => panic!("unexpected event from peer 2: {:?}", e) + }; + + assert_eq!(error, OutboundFailure::ConnectionClosed); + }); +} + #[test] fn ping_protocol_throttled() { let ping = Ping("ping".to_string().into_bytes()); @@ -211,11 +269,11 @@ fn ping_protocol_throttled() { let (peer1_id, trans) = mk_transport(); let ping_proto1 = RequestResponse::throttled(PingCodec(), protocols.clone(), cfg.clone()); - let mut swarm1 = Swarm::new(trans, ping_proto1, peer1_id.clone()); + let mut swarm1 = Swarm::new(trans, ping_proto1, peer1_id); let (peer2_id, trans) = mk_transport(); let ping_proto2 = RequestResponse::throttled(PingCodec(), protocols, cfg); - let mut swarm2 = Swarm::new(trans, ping_proto2, peer2_id.clone()); + let mut swarm2 = Swarm::new(trans, ping_proto2, peer2_id); let (mut tx, mut rx) = mpsc::channel::(1);