fix(swarm-derive): add bounds to OutEvent on generic fields (#3393)

This PR isolates the bugfix for the `NetworkBehaviour` derive implementation for structures with generic fields. When `out_event` was not provided, the generated enum was missing the `NetworkBehaviour` impl constraint for the generic variants whilst using the generics for `<Generic>::OutEvent`.

Meanwhile I also found that the generated `poll` function `loop`s the sub behaviours and either `return`'s when `Poll::Ready` or `break`'s when `Poll::Pending`. This is a relict from when we still had `NetworkBehaviourEventProcess` which had added a branch within this loop that did not `return` but consume the event and `continue`. This trait was removed a while ago meaning this `loop` is no longer needed.
This commit is contained in:
João Oliveira
2023-02-01 21:30:27 +00:00
committed by GitHub
parent 063aab5909
commit b98b03eb7e
3 changed files with 129 additions and 46 deletions

View File

@ -1,9 +1,14 @@
# 0.32.0 [unreleased]
- Fix `NetworkBehaviour` Derive macro for generic types when `out_event` was not provided. Previously the enum generated
didn't have the `NetworkBehaviour` impl constraints whilst using the generics for `<Generic>::OutEvent`.
See [PR 3393].
- Replace `NetworkBehaviour` Derive macro deprecated `inject_*` method implementations
with the new `on_swarm_event` and `on_connection_handler_event`.
See [PR 3011] and [PR 3264].
[PR 3393]: https://github.com/libp2p/rust-libp2p/pull/3393
[PR 3011]: https://github.com/libp2p/rust-libp2p/pull/3011
[PR 3264]: https://github.com/libp2p/rust-libp2p/pull/3264
@ -123,3 +128,4 @@ ambiguity. [PR 1681](https://github.com/libp2p/rust-libp2p/pull/1681).
mechanism through `#[behaviour(event_process = false)]`. This is
useful if users want to process all events while polling the
swarm through `SwarmEvent::Behaviour`.

View File

@ -106,40 +106,81 @@ fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream {
}
// User did not provide `OutEvent`. Generate it.
None => {
let name: syn::Type = syn::parse_str(&(ast.ident.to_string() + "Event")).unwrap();
let enum_name_str = ast.ident.to_string() + "Event";
let enum_name: syn::Type = syn::parse_str(&enum_name_str).unwrap();
let definition = {
let fields = data_struct
.fields
.iter()
.map(|field| {
let variant: syn::Variant = syn::parse_str(
&field
.ident
.clone()
.expect(
"Fields of NetworkBehaviour implementation to be named.",
)
.to_string()
.to_upper_camel_case(),
)
.unwrap();
let ty = &field.ty;
quote! {#variant(<#ty as #trait_to_impl>::OutEvent)}
})
.collect::<Vec<_>>();
let fields = data_struct.fields.iter().map(|field| {
let variant: syn::Variant = syn::parse_str(
&field
.ident
.clone()
.expect("Fields of NetworkBehaviour implementation to be named.")
.to_string()
.to_upper_camel_case(),
)
.unwrap();
let ty = &field.ty;
(variant, ty)
});
let enum_variants = fields
.clone()
.map(|(variant, ty)| quote! {#variant(<#ty as #trait_to_impl>::OutEvent)});
let visibility = &ast.vis;
let additional = fields
.clone()
.map(|(_variant, tp)| quote! { #tp : #trait_to_impl })
.collect::<Vec<_>>();
let additional_debug = fields
.clone()
.map(|(_variant, ty)| quote! { <#ty as #trait_to_impl>::OutEvent : ::core::fmt::Debug })
.collect::<Vec<_>>();
let where_clause = {
if let Some(where_clause) = where_clause {
if where_clause.predicates.trailing_punct() {
Some(quote! {#where_clause #(#additional),* })
} else {
Some(quote! {#where_clause, #(#additional),*})
}
} else if additional.is_empty() {
None
} else {
Some(quote! {where #(#additional),*})
}
};
let where_clause_debug = where_clause
.as_ref()
.map(|where_clause| quote! {#where_clause, #(#additional_debug),*});
let match_variants = fields.map(|(variant, _ty)| variant);
let msg = format!("`NetworkBehaviour::OutEvent` produced by {name}.");
Some(quote! {
#[derive(::std::fmt::Debug)]
#visibility enum #name #impl_generics
#[doc = #msg]
#visibility enum #enum_name #ty_generics
#where_clause
{
#(#fields),*
#(#enum_variants),*
}
impl #impl_generics ::core::fmt::Debug for #enum_name #ty_generics #where_clause_debug {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
match &self {
#(#enum_name::#match_variants(event) => {
write!(f, "{}: {:?}", #enum_name_str, event)
}),*
}
}
}
})
};
let from_clauses = vec![];
(name, definition, from_clauses)
(enum_name, definition, from_clauses)
}
}
};
@ -664,30 +705,28 @@ fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream {
}
};
Some(quote!{
loop {
match #trait_to_impl::poll(&mut self.#field, cx, poll_params) {
#generate_event_match_arm
std::task::Poll::Ready(#network_behaviour_action::Dial { opts, handler: provided_handler }) => {
return std::task::Poll::Ready(#network_behaviour_action::Dial { opts, handler: #provided_handler_and_new_handlers });
}
std::task::Poll::Ready(#network_behaviour_action::NotifyHandler { peer_id, handler, event }) => {
return std::task::Poll::Ready(#network_behaviour_action::NotifyHandler {
peer_id,
handler,
event: #wrapped_event,
});
}
std::task::Poll::Ready(#network_behaviour_action::ReportObservedAddr { address, score }) => {
return std::task::Poll::Ready(#network_behaviour_action::ReportObservedAddr { address, score });
}
std::task::Poll::Ready(#network_behaviour_action::CloseConnection { peer_id, connection }) => {
return std::task::Poll::Ready(#network_behaviour_action::CloseConnection { peer_id, connection });
}
std::task::Poll::Pending => break,
quote!{
match #trait_to_impl::poll(&mut self.#field, cx, poll_params) {
#generate_event_match_arm
std::task::Poll::Ready(#network_behaviour_action::Dial { opts, handler: provided_handler }) => {
return std::task::Poll::Ready(#network_behaviour_action::Dial { opts, handler: #provided_handler_and_new_handlers });
}
std::task::Poll::Ready(#network_behaviour_action::NotifyHandler { peer_id, handler, event }) => {
return std::task::Poll::Ready(#network_behaviour_action::NotifyHandler {
peer_id,
handler,
event: #wrapped_event,
});
}
std::task::Poll::Ready(#network_behaviour_action::ReportObservedAddr { address, score }) => {
return std::task::Poll::Ready(#network_behaviour_action::ReportObservedAddr { address, score });
}
std::task::Poll::Ready(#network_behaviour_action::CloseConnection { peer_id, connection }) => {
return std::task::Poll::Ready(#network_behaviour_action::CloseConnection { peer_id, connection });
}
std::task::Poll::Pending => {},
}
})
}
});
let out_event_reference = if out_event_definition.is_some() {

View File

@ -308,6 +308,44 @@ fn with_either() {
}
}
#[test]
fn with_generics() {
#[allow(dead_code)]
#[derive(NetworkBehaviour)]
#[behaviour(prelude = "libp2p_swarm::derive_prelude")]
struct Foo<A, B> {
a: A,
b: B,
}
#[allow(dead_code)]
fn foo() {
require_net_behaviour::<
Foo<
libp2p_kad::Kademlia<libp2p_kad::record::store::MemoryStore>,
libp2p_ping::Behaviour,
>,
>();
}
}
#[test]
fn with_generics_mixed() {
#[allow(dead_code)]
#[derive(NetworkBehaviour)]
#[behaviour(prelude = "libp2p_swarm::derive_prelude")]
struct Foo<A> {
a: A,
ping: libp2p_ping::Behaviour,
}
#[allow(dead_code)]
fn foo() {
require_net_behaviour::<Foo<libp2p_kad::Kademlia<libp2p_kad::record::store::MemoryStore>>>(
);
}
}
#[test]
fn custom_event_with_either() {
use either::Either;