From b98b03eb7eae556b4cd9ff32f19fc23a961224e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Oliveira?= Date: Wed, 1 Feb 2023 21:30:27 +0000 Subject: [PATCH] fix(swarm-derive): add bounds to `OutEvent` on generic fields (#3393) This PR isolates the bugfix for the `NetworkBehaviour` derive implementation for structures with generic fields. When `out_event` was not provided, the generated enum was missing the `NetworkBehaviour` impl constraint for the generic variants whilst using the generics for `::OutEvent`. Meanwhile I also found that the generated `poll` function `loop`s the sub behaviours and either `return`'s when `Poll::Ready` or `break`'s when `Poll::Pending`. This is a relict from when we still had `NetworkBehaviourEventProcess` which had added a branch within this loop that did not `return` but consume the event and `continue`. This trait was removed a while ago meaning this `loop` is no longer needed. --- swarm-derive/CHANGELOG.md | 6 ++ swarm-derive/src/lib.rs | 131 +++++++++++++++++++++++------------- swarm/tests/swarm_derive.rs | 38 +++++++++++ 3 files changed, 129 insertions(+), 46 deletions(-) diff --git a/swarm-derive/CHANGELOG.md b/swarm-derive/CHANGELOG.md index b7ea340b..09de2458 100644 --- a/swarm-derive/CHANGELOG.md +++ b/swarm-derive/CHANGELOG.md @@ -1,9 +1,14 @@ # 0.32.0 [unreleased] +- Fix `NetworkBehaviour` Derive macro for generic types when `out_event` was not provided. Previously the enum generated + didn't have the `NetworkBehaviour` impl constraints whilst using the generics for `::OutEvent`. + See [PR 3393]. + - Replace `NetworkBehaviour` Derive macro deprecated `inject_*` method implementations with the new `on_swarm_event` and `on_connection_handler_event`. See [PR 3011] and [PR 3264]. +[PR 3393]: https://github.com/libp2p/rust-libp2p/pull/3393 [PR 3011]: https://github.com/libp2p/rust-libp2p/pull/3011 [PR 3264]: https://github.com/libp2p/rust-libp2p/pull/3264 @@ -123,3 +128,4 @@ ambiguity. [PR 1681](https://github.com/libp2p/rust-libp2p/pull/1681). mechanism through `#[behaviour(event_process = false)]`. This is useful if users want to process all events while polling the swarm through `SwarmEvent::Behaviour`. + diff --git a/swarm-derive/src/lib.rs b/swarm-derive/src/lib.rs index aa2a141c..d4bd3b39 100644 --- a/swarm-derive/src/lib.rs +++ b/swarm-derive/src/lib.rs @@ -106,40 +106,81 @@ fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream { } // User did not provide `OutEvent`. Generate it. None => { - let name: syn::Type = syn::parse_str(&(ast.ident.to_string() + "Event")).unwrap(); + let enum_name_str = ast.ident.to_string() + "Event"; + let enum_name: syn::Type = syn::parse_str(&enum_name_str).unwrap(); let definition = { - let fields = data_struct - .fields - .iter() - .map(|field| { - let variant: syn::Variant = syn::parse_str( - &field - .ident - .clone() - .expect( - "Fields of NetworkBehaviour implementation to be named.", - ) - .to_string() - .to_upper_camel_case(), - ) - .unwrap(); - let ty = &field.ty; - quote! {#variant(<#ty as #trait_to_impl>::OutEvent)} - }) - .collect::>(); + let fields = data_struct.fields.iter().map(|field| { + let variant: syn::Variant = syn::parse_str( + &field + .ident + .clone() + .expect("Fields of NetworkBehaviour implementation to be named.") + .to_string() + .to_upper_camel_case(), + ) + .unwrap(); + let ty = &field.ty; + (variant, ty) + }); + + let enum_variants = fields + .clone() + .map(|(variant, ty)| quote! {#variant(<#ty as #trait_to_impl>::OutEvent)}); + let visibility = &ast.vis; + let additional = fields + .clone() + .map(|(_variant, tp)| quote! { #tp : #trait_to_impl }) + .collect::>(); + + let additional_debug = fields + .clone() + .map(|(_variant, ty)| quote! { <#ty as #trait_to_impl>::OutEvent : ::core::fmt::Debug }) + .collect::>(); + + let where_clause = { + if let Some(where_clause) = where_clause { + if where_clause.predicates.trailing_punct() { + Some(quote! {#where_clause #(#additional),* }) + } else { + Some(quote! {#where_clause, #(#additional),*}) + } + } else if additional.is_empty() { + None + } else { + Some(quote! {where #(#additional),*}) + } + }; + + let where_clause_debug = where_clause + .as_ref() + .map(|where_clause| quote! {#where_clause, #(#additional_debug),*}); + + let match_variants = fields.map(|(variant, _ty)| variant); + let msg = format!("`NetworkBehaviour::OutEvent` produced by {name}."); + Some(quote! { - #[derive(::std::fmt::Debug)] - #visibility enum #name #impl_generics + #[doc = #msg] + #visibility enum #enum_name #ty_generics #where_clause { - #(#fields),* + #(#enum_variants),* + } + + impl #impl_generics ::core::fmt::Debug for #enum_name #ty_generics #where_clause_debug { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + match &self { + #(#enum_name::#match_variants(event) => { + write!(f, "{}: {:?}", #enum_name_str, event) + }),* + } + } } }) }; let from_clauses = vec![]; - (name, definition, from_clauses) + (enum_name, definition, from_clauses) } } }; @@ -664,30 +705,28 @@ fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream { } }; - Some(quote!{ - loop { - match #trait_to_impl::poll(&mut self.#field, cx, poll_params) { - #generate_event_match_arm - std::task::Poll::Ready(#network_behaviour_action::Dial { opts, handler: provided_handler }) => { - return std::task::Poll::Ready(#network_behaviour_action::Dial { opts, handler: #provided_handler_and_new_handlers }); - } - std::task::Poll::Ready(#network_behaviour_action::NotifyHandler { peer_id, handler, event }) => { - return std::task::Poll::Ready(#network_behaviour_action::NotifyHandler { - peer_id, - handler, - event: #wrapped_event, - }); - } - std::task::Poll::Ready(#network_behaviour_action::ReportObservedAddr { address, score }) => { - return std::task::Poll::Ready(#network_behaviour_action::ReportObservedAddr { address, score }); - } - std::task::Poll::Ready(#network_behaviour_action::CloseConnection { peer_id, connection }) => { - return std::task::Poll::Ready(#network_behaviour_action::CloseConnection { peer_id, connection }); - } - std::task::Poll::Pending => break, + quote!{ + match #trait_to_impl::poll(&mut self.#field, cx, poll_params) { + #generate_event_match_arm + std::task::Poll::Ready(#network_behaviour_action::Dial { opts, handler: provided_handler }) => { + return std::task::Poll::Ready(#network_behaviour_action::Dial { opts, handler: #provided_handler_and_new_handlers }); } + std::task::Poll::Ready(#network_behaviour_action::NotifyHandler { peer_id, handler, event }) => { + return std::task::Poll::Ready(#network_behaviour_action::NotifyHandler { + peer_id, + handler, + event: #wrapped_event, + }); + } + std::task::Poll::Ready(#network_behaviour_action::ReportObservedAddr { address, score }) => { + return std::task::Poll::Ready(#network_behaviour_action::ReportObservedAddr { address, score }); + } + std::task::Poll::Ready(#network_behaviour_action::CloseConnection { peer_id, connection }) => { + return std::task::Poll::Ready(#network_behaviour_action::CloseConnection { peer_id, connection }); + } + std::task::Poll::Pending => {}, } - }) + } }); let out_event_reference = if out_event_definition.is_some() { diff --git a/swarm/tests/swarm_derive.rs b/swarm/tests/swarm_derive.rs index 0d3b116f..48eda804 100644 --- a/swarm/tests/swarm_derive.rs +++ b/swarm/tests/swarm_derive.rs @@ -308,6 +308,44 @@ fn with_either() { } } +#[test] +fn with_generics() { + #[allow(dead_code)] + #[derive(NetworkBehaviour)] + #[behaviour(prelude = "libp2p_swarm::derive_prelude")] + struct Foo { + a: A, + b: B, + } + + #[allow(dead_code)] + fn foo() { + require_net_behaviour::< + Foo< + libp2p_kad::Kademlia, + libp2p_ping::Behaviour, + >, + >(); + } +} + +#[test] +fn with_generics_mixed() { + #[allow(dead_code)] + #[derive(NetworkBehaviour)] + #[behaviour(prelude = "libp2p_swarm::derive_prelude")] + struct Foo { + a: A, + ping: libp2p_ping::Behaviour, + } + + #[allow(dead_code)] + fn foo() { + require_net_behaviour::>>( + ); + } +} + #[test] fn custom_event_with_either() { use either::Either;