From c0cad447c1b1742f59ca6de01d7b0def1c341156 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Wed, 4 Apr 2018 08:24:19 -0700 Subject: [PATCH] Initial support for closures This commit starts wasm-bindgen down the path of supporting closures. We discussed this at the recent Rust All-Hands but I ended up needing to pretty significantly scale back the ambitions of what closures are supported. This commit is just the initial support and provides only a small amount of support but will hopefully provide a good basis for future implementations. Specifically this commit adds support for passing `&Fn(...)` to an *imported function*, but nothing elese. The `&Fn` type can have any lifetime and the JS object is invalidated as soon as the import returns. The arguments and return value of `Fn` must currently implement the `WasmAbi` trait, aka they can't require any conversions like strings/types/etc. I'd like to soon expand this to `&mut FnMut` as well as `'static` closures that can be passed around for a long time in JS, but for now I'm putting that off until later. I'm not currently sure how to implement richer argument types, but hopefully that can be figured out at some point! --- crates/backend/src/ast.rs | 112 ++++++++++++++++++++++++---------- crates/backend/src/codegen.rs | 48 ++++++++------- crates/backend/src/literal.rs | 29 ++++++--- crates/cli-support/src/js.rs | 54 ++++++++++++++++ crates/cli-support/src/lib.rs | 1 + crates/shared/src/lib.rs | 4 +- src/convert.rs | 73 +++++++++++++++++++++- tests/closures.rs | 90 +++++++++++++++++++++++++++ 8 files changed, 347 insertions(+), 64 deletions(-) create mode 100644 tests/closures.rs diff --git a/crates/backend/src/ast.rs b/crates/backend/src/ast.rs index ece288fe..3b945c0b 100644 --- a/crates/backend/src/ast.rs +++ b/crates/backend/src/ast.rs @@ -81,10 +81,25 @@ pub struct Variant { pub value: u32, } -pub enum Type { - ByRef(syn::Type), - ByMutRef(syn::Type), - ByValue(syn::Type), +pub struct Type { + pub ty: syn::Type, + pub kind: TypeKind, + pub loc: TypeLocation, +} + +#[derive(Copy, Clone)] +pub enum TypeKind { + ByRef, + ByMutRef, + ByValue, +} + +#[derive(Copy, Clone)] +pub enum TypeLocation { + ImportArgument, + ImportRet, + ExportArgument, + ExportRet, } impl Program { @@ -197,6 +212,7 @@ impl Program { opts, method.vis, true, + false, ); self.exports.push(Export { class: Some(class), @@ -284,7 +300,15 @@ impl Program { pub fn push_foreign_fn(&mut self, f: syn::ForeignItemFn, opts: BindgenAttrs) -> ImportKind { let js_name = opts.js_name().unwrap_or(f.ident); - let mut wasm = Function::from_decl(js_name, f.decl, f.attrs, opts, f.vis, false).0; + let mut wasm = Function::from_decl( + js_name, + f.decl, + f.attrs, + opts, + f.vis, + false, + true, + ).0; if wasm.opts.catch() { // TODO: this assumes a whole bunch: // @@ -301,11 +325,7 @@ impl Program { let class = wasm.arguments .get(0) .expect("methods must have at least one argument"); - let class = match *class { - Type::ByRef(ref t) | Type::ByValue(ref t) => t, - Type::ByMutRef(_) => panic!("first method argument cannot be mutable ref"), - }; - let class_name = match *class { + let class_name = match class.ty { syn::Type::Path(syn::TypePath { qself: None, ref path, @@ -317,11 +337,11 @@ impl Program { ImportFunctionKind::Method { class: class_name.as_ref().to_string(), - ty: class.clone(), + ty: class.ty.clone(), } } else if wasm.opts.constructor() { let class = match wasm.ret { - Some(Type::ByValue(ref t)) => t, + Some(Type { ref ty, kind: TypeKind::ByValue, .. }) => ty, _ => panic!("constructor returns must be bare types"), }; let class_name = match *class { @@ -416,7 +436,15 @@ impl Function { panic!("can only bindgen safe functions"); } - Function::from_decl(input.ident, input.decl, input.attrs, opts, input.vis, false).0 + Function::from_decl( + input.ident, + input.decl, + input.attrs, + opts, + input.vis, + false, + false, + ).0 } pub fn from_decl( @@ -426,6 +454,7 @@ impl Function { opts: BindgenAttrs, vis: syn::Visibility, allow_self: bool, + import: bool, ) -> (Function, Option) { if decl.variadic.is_some() { panic!("can't bindgen variadic functions") @@ -449,12 +478,24 @@ impl Function { } _ => panic!("arguments cannot be `self` or ignored"), }) - .map(|arg| Type::from(&arg.ty)) + .map(|arg| { + Type::from(&arg.ty, if import { + TypeLocation::ImportArgument + } else { + TypeLocation::ExportArgument + }) + }) .collect::>(); let ret = match decl.output { syn::ReturnType::Default => None, - syn::ReturnType::Type(_, ref t) => Some(Type::from(t)), + syn::ReturnType::Type(_, ref t) => { + Some(Type::from(t, if import { + TypeLocation::ImportRet + } else { + TypeLocation::ExportRet + })) + } }; ( @@ -486,19 +527,6 @@ pub fn extract_path_ident(path: &syn::Path) -> Option { path.segments.first().map(|v| v.value().ident) } -impl Type { - pub fn from(ty: &syn::Type) -> Type { - if let syn::Type::Reference(ref r) = *ty { - return if r.mutability.is_some() { - Type::ByMutRef((*r.elem).clone()) - } else { - Type::ByRef((*r.elem).clone()) - } - } - Type::ByValue(ty.clone()) - } -} - impl Export { pub fn rust_symbol(&self) -> syn::Ident { let mut generated_name = format!("__wasm_bindgen_generated"); @@ -540,6 +568,22 @@ impl Struct { } } +impl Type { + pub fn from(ty: &syn::Type, loc: TypeLocation) -> Type { + let (ty, kind) = match *ty { + syn::Type::Reference(ref r) => { + if r.mutability.is_some() { + ((*r.elem).clone(), TypeKind::ByMutRef) + } else { + ((*r.elem).clone(), TypeKind::ByRef) + } + } + _ => (ty.clone(), TypeKind::ByValue), + }; + Type { loc, ty, kind } + } +} + #[derive(Default)] pub struct BindgenAttrs { attrs: Vec, @@ -719,12 +763,12 @@ impl syn::synom::Synom for BindgenAttr { } fn extract_first_ty_param(ty: Option<&Type>) -> Option> { - let ty = match ty { + let t = match ty { Some(t) => t, None => return Some(None), }; - let ty = match *ty { - Type::ByValue(ref t) => t, + let ty = match *t { + Type { ref ty, kind: TypeKind::ByValue, .. } => ty, _ => return None, }; let path = match *ty { @@ -747,7 +791,11 @@ fn extract_first_ty_param(ty: Option<&Type>) -> Option> { syn::Type::Tuple(ref t) if t.elems.len() == 0 => return Some(None), _ => {} } - Some(Some(Type::from(ty))) + Some(Some(Type { + ty: ty.clone(), + kind: TypeKind::ByValue, + loc: t.loc, + })) } fn term<'a>(cursor: syn::buffer::Cursor<'a>, name: &str) -> syn::synom::PResult<'a, ()> { diff --git a/crates/backend/src/codegen.rs b/crates/backend/src/codegen.rs index 8ea9e1c6..9841c006 100644 --- a/crates/backend/src/codegen.rs +++ b/crates/backend/src/codegen.rs @@ -210,8 +210,9 @@ impl ToTokens for ast::Export { for (i, ty) in self.function.arguments.iter().enumerate() { let i = i + offset; let ident = syn::Ident::from(format!("arg{}", i)); - match *ty { - ast::Type::ByValue(ref t) => { + let t = &ty.ty; + match ty.kind { + ast::TypeKind::ByValue => { args.push(quote! { #ident: <#t as ::wasm_bindgen::convert::WasmBoundary>::Abi }); @@ -222,25 +223,25 @@ impl ToTokens for ast::Export { }; }); } - ast::Type::ByRef(ref ty) => { + ast::TypeKind::ByRef => { args.push(quote! { - #ident: <#ty as ::wasm_bindgen::convert::FromRefWasmBoundary>::Abi + #ident: <#t as ::wasm_bindgen::convert::FromRefWasmBoundary>::Abi }); arg_conversions.push(quote! { let #ident = unsafe { - <#ty as ::wasm_bindgen::convert::FromRefWasmBoundary> + <#t as ::wasm_bindgen::convert::FromRefWasmBoundary> ::from_abi_ref(#ident, &mut __stack) }; let #ident = &*#ident; }); } - ast::Type::ByMutRef(ref ty) => { + ast::TypeKind::ByMutRef => { args.push(quote! { - #ident: <#ty as ::wasm_bindgen::convert::FromRefMutWasmBoundary>::Abi + #ident: <#t as ::wasm_bindgen::convert::FromRefMutWasmBoundary>::Abi }); arg_conversions.push(quote! { let mut #ident = unsafe { - <#ty as ::wasm_bindgen::convert::FromRefMutWasmBoundary> + <#t as ::wasm_bindgen::convert::FromRefMutWasmBoundary> ::from_abi_ref_mut(#ident, &mut __stack) }; let #ident = &mut *#ident; @@ -252,19 +253,19 @@ impl ToTokens for ast::Export { let ret_ty; let convert_ret; match self.function.ret { - Some(ast::Type::ByValue(ref t)) => { + Some(ast::Type { ref ty, kind: ast::TypeKind::ByValue, .. }) => { ret_ty = quote! { - -> <#t as ::wasm_bindgen::convert::WasmBoundary>::Abi + -> <#ty as ::wasm_bindgen::convert::WasmBoundary>::Abi }; convert_ret = quote! { - <#t as ::wasm_bindgen::convert::WasmBoundary> + <#ty as ::wasm_bindgen::convert::WasmBoundary> ::into_abi(#ret, &mut unsafe { ::wasm_bindgen::convert::GlobalStack::new() }) }; } - Some(ast::Type::ByMutRef(_)) - | Some(ast::Type::ByRef(_)) => { + Some(ast::Type { kind: ast::TypeKind::ByMutRef, .. }) | + Some(ast::Type { kind: ast::TypeKind::ByRef, .. }) => { panic!("can't return a borrowed ref"); } None => { @@ -432,8 +433,9 @@ impl ToTokens for ast::ImportFunction { }); for (i, (ty, name)) in self.function.arguments.iter().zip(names).enumerate() { - match *ty { - ast::Type::ByValue(ref t) => { + let t = &ty.ty; + match ty.kind { + ast::TypeKind::ByValue => { abi_argument_names.push(name); abi_arguments.push(quote! { #name: <#t as ::wasm_bindgen::convert::WasmBoundary>::Abi @@ -448,8 +450,8 @@ impl ToTokens for ast::ImportFunction { ::into_abi(#var, &mut __stack); }); } - ast::Type::ByMutRef(_) => panic!("urgh mut"), - ast::Type::ByRef(ref t) => { + ast::TypeKind::ByMutRef => panic!("urgh mut"), + ast::TypeKind::ByRef => { abi_argument_names.push(name); abi_arguments.push(quote! { #name: u32 }); let var = if i == 0 && is_method { @@ -467,20 +469,22 @@ impl ToTokens for ast::ImportFunction { let abi_ret; let mut convert_ret; match self.function.ret { - Some(ast::Type::ByValue(ref t)) => { + Some(ast::Type { ref ty, kind: ast::TypeKind::ByValue, .. }) => { abi_ret = quote! { - <#t as ::wasm_bindgen::convert::WasmBoundary>::Abi + <#ty as ::wasm_bindgen::convert::WasmBoundary>::Abi }; convert_ret = quote! { - <#t as ::wasm_bindgen::convert::WasmBoundary> + <#ty as ::wasm_bindgen::convert::WasmBoundary> ::from_abi( #ret_ident, &mut ::wasm_bindgen::convert::GlobalStack::new(), ) }; } - Some(ast::Type::ByRef(_)) - | Some(ast::Type::ByMutRef(_)) => panic!("can't return a borrowed ref"), + Some(ast::Type { kind: ast::TypeKind::ByRef, .. }) | + Some(ast::Type { kind: ast::TypeKind::ByMutRef, .. }) => { + panic!("can't return a borrowed ref") + } None => { abi_ret = quote! { () }; convert_ret = quote! { () }; diff --git a/crates/backend/src/literal.rs b/crates/backend/src/literal.rs index 992577b4..24e6d06e 100644 --- a/crates/backend/src/literal.rs +++ b/crates/backend/src/literal.rs @@ -139,18 +139,31 @@ impl Literal for ast::Function { impl Literal for ast::Type { fn literal(&self, a: &mut LiteralBuilder) { - match *self { - ast::Type::ByValue(ref t) => { + let t = &self.ty; + match self.kind { + ast::TypeKind::ByValue => { a.as_char(quote! { <#t as ::wasm_bindgen::convert::WasmBoundary>::DESCRIPTOR }); } - ast::Type::ByRef(ref ty) | ast::Type::ByMutRef(ref ty) => { - // TODO: this assumes `ToRef*` and `FromRef*` use the same - // descriptor. - a.as_char(quote! { - <#ty as ::wasm_bindgen::convert::FromRefWasmBoundary>::DESCRIPTOR - }); + ast::TypeKind::ByRef | + ast::TypeKind::ByMutRef => { + match self.loc { + ast::TypeLocation::ImportArgument | + ast::TypeLocation::ExportRet => { + a.as_char(quote! { + <#t as ::wasm_bindgen::convert::ToRefWasmBoundary> + ::DESCRIPTOR + }); + } + ast::TypeLocation::ImportRet | + ast::TypeLocation::ExportArgument => { + a.as_char(quote! { + <#t as ::wasm_bindgen::convert::FromRefWasmBoundary> + ::DESCRIPTOR + }); + } + } } } } diff --git a/crates/cli-support/src/js.rs b/crates/cli-support/src/js.rs index 9a48df56..3a30d032 100644 --- a/crates/cli-support/src/js.rs +++ b/crates/cli-support/src/js.rs @@ -19,6 +19,7 @@ pub struct Context<'a> { pub custom_type_names: HashMap, pub imported_names: HashSet, pub exported_classes: HashMap, + pub function_table_needed: bool, } #[derive(Default)] @@ -246,6 +247,7 @@ impl<'a> Context<'a> { ); self.unexport_unused_internal_exports(); + self.export_table(); (js, self.typescript.clone()) } @@ -313,6 +315,22 @@ impl<'a> Context<'a> { } } + fn export_table(&mut self) { + if !self.function_table_needed { + return + } + for section in self.module.sections_mut() { + let exports = match *section { + Section::Export(ref mut s) => s, + _ => continue, + }; + let entry = ExportEntry::new("__wbg_function_table".to_string(), + Internal::Table(0)); + exports.entries_mut().push(entry); + break + } + } + fn rewrite_imports(&mut self, module_name: &str) { for (name, contents) in self._rewrite_imports(module_name) { self.export(&name, &contents); @@ -1404,6 +1422,7 @@ impl<'a, 'b> SubContext<'a, 'b> { let mut abi_args = Vec::new(); let mut extra = String::new(); + let mut finally = String::new(); let mut next_global = 0; for (i, arg) in import.function.arguments.iter().enumerate() { @@ -1419,6 +1438,30 @@ impl<'a, 'b> SubContext<'a, 'b> { self.cx.expose_get_object(); format!("getObject(arg{})", i) } + shared::TYPE_STACK_FUNC0 | + shared::TYPE_STACK_FUNC1 => { + let nargs = *arg - shared::TYPE_STACK_FUNC0; + let args = (0..nargs) + .map(|i| format!("arg{}", i)) + .collect::>() + .join(", "); + self.cx.expose_get_global_argument(); + self.cx.function_table_needed = true; + let sep = if nargs == 0 {""} else {","}; + extra.push_str(&format!(" + let cb{0} = function({args}) {{ + return this.f(this.a, this.b {sep} {args}); + }}; + cb{0}.f = wasm.__wbg_function_table.get(arg{0}); + cb{0}.a = getGlobalArgument({next_global}); + cb{0}.b = getGlobalArgument({next_global} + 1); + ", i, next_global = next_global, args = args, sep = sep)); + next_global += 2; + finally.push_str(&format!(" + cb{0}.a = cb{0}.b = 0; + ", i)); + format!("cb{0}.bind(cb{0})", i) + } other => { match VectorType::from(other) { Some(ty) => { @@ -1585,6 +1628,17 @@ impl<'a, 'b> SubContext<'a, 'b> { } else { invoc }; + let invoc = if finally.len() > 0 { + format!(" + try {{ + {} + }} finally {{ + {} + }} + ", invoc, finally) + } else { + invoc + }; dst.push_str(&abi_args.join(", ")); dst.push_str(") {\n"); diff --git a/crates/cli-support/src/lib.rs b/crates/cli-support/src/lib.rs index ad0e3b62..a8015d40 100644 --- a/crates/cli-support/src/lib.rs +++ b/crates/cli-support/src/lib.rs @@ -94,6 +94,7 @@ impl Bindgen { exported_classes: Default::default(), config: &self, module: &mut module, + function_table_needed: false, }; for program in programs.iter() { cx.add_custom_type_names(program); diff --git a/crates/shared/src/lib.rs b/crates/shared/src/lib.rs index 30126d9f..87b56c6a 100644 --- a/crates/shared/src/lib.rs +++ b/crates/shared/src/lib.rs @@ -146,8 +146,10 @@ pub const TYPE_SLICE_F64: u32 = 20; pub const TYPE_VECTOR_F64: u32 = 21; pub const TYPE_JS_OWNED: u32 = 22; pub const TYPE_JS_REF: u32 = 23; +pub const TYPE_STACK_FUNC0: u32 = 24; +pub const TYPE_STACK_FUNC1: u32 = 25; -pub const TYPE_CUSTOM_START: u32 = 24; +pub const TYPE_CUSTOM_START: u32 = 26; pub const TYPE_CUSTOM_REF_FLAG: u32 = 1; pub fn name_to_descriptor(name: &str) -> u32 { diff --git a/src/convert.rs b/src/convert.rs index e5069247..07866e6d 100644 --- a/src/convert.rs +++ b/src/convert.rs @@ -3,7 +3,7 @@ use std::ops::{Deref, DerefMut}; use std::slice; use std::str; -use super::JsValue; +use {JsValue, throw}; #[derive(PartialEq, Eq, Copy, Clone)] pub struct Descriptor { @@ -21,6 +21,9 @@ pub const DESCRIPTOR_BOOLEAN: Descriptor = Descriptor { __x: *b" 5", }; pub const DESCRIPTOR_JS_OWNED: Descriptor = Descriptor { __x: *b" 22", }; pub const DESCRIPTOR_JS_REF: Descriptor = Descriptor { __x: *b" 23", }; +pub const DESCRIPTOR_STACK_FUNC0: Descriptor = Descriptor { __x: *b" 24", }; +pub const DESCRIPTOR_STACK_FUNC1: Descriptor = Descriptor { __x: *b" 25", }; + pub trait WasmBoundary { type Abi: WasmAbi; const DESCRIPTOR: Descriptor; @@ -337,3 +340,71 @@ impl Stack for GlobalStack { pub unsafe extern fn __wbindgen_global_argument_ptr() -> *mut u32 { GLOBAL_STACK.as_mut_ptr() } + +macro_rules! stack_closures { + ($( + ($($var:ident)*) => $descriptor:ident + )*) => ($( + impl<'a, $($var,)* R> ToRefWasmBoundary for Fn($($var),*) -> R + 'a + where $($var: WasmAbi,)* + R: WasmAbi + { + type Abi = u32; + const DESCRIPTOR: Descriptor = $descriptor; + + fn to_abi_ref(&self, extra: &mut Stack) -> u32 { + #[allow(non_snake_case)] + unsafe extern fn invoke<$($var,)* R>( + a: usize, + b: usize, + $($var: $var),* + ) -> R { + if a == 0 { + throw("stack closure has been destroyed already"); + } + let f: &Fn($($var),*) -> R = mem::transmute((a, b)); + f($($var),*) + } + unsafe { + let (a, b): (usize, usize) = mem::transmute(self); + extra.push(a as u32); + extra.push(b as u32); + invoke::<$($var,)* R> as u32 + } + } + } + + impl<'a, $($var,)*> ToRefWasmBoundary for Fn($($var),*) + 'a + where $($var: WasmAbi,)* + { + type Abi = u32; + const DESCRIPTOR: Descriptor = $descriptor; + + fn to_abi_ref(&self, extra: &mut Stack) -> u32 { + #[allow(non_snake_case)] + unsafe extern fn invoke<$($var,)* >( + a: usize, + b: usize, + $($var: $var),* + ) { + if a == 0 { + throw("stack closure has been destroyed already"); + } + let f: &Fn($($var),*) = mem::transmute((a, b)); + f($($var),*) + } + unsafe { + let (a, b): (usize, usize) = mem::transmute(self); + extra.push(a as u32); + extra.push(b as u32); + invoke::<$($var,)*> as u32 + } + } + } + )*) +} + +stack_closures! { + () => DESCRIPTOR_STACK_FUNC0 + (A) => DESCRIPTOR_STACK_FUNC1 +} diff --git a/tests/closures.rs b/tests/closures.rs new file mode 100644 index 00000000..67bd867a --- /dev/null +++ b/tests/closures.rs @@ -0,0 +1,90 @@ +extern crate test_support; + +#[test] +fn works() { + test_support::project() + .file("src/lib.rs", r#" + #![feature(proc_macro, wasm_custom_section, wasm_import_module)] + + extern crate wasm_bindgen; + + use std::cell::Cell; + use wasm_bindgen::prelude::*; + + #[wasm_bindgen(module = "./test")] + extern { + fn call(a: &Fn()); + fn thread(a: &Fn(u32) -> u32) -> u32; + } + + #[wasm_bindgen] + pub fn run() { + let a = Cell::new(false); + call(&|| a.set(true)); + assert!(a.get()); + + assert_eq!(thread(&|a| a + 1), 3); + } + "#) + .file("test.ts", r#" + import { run } from "./out"; + + export function call(a: any) { + a(); + } + + export function thread(a: any) { + return a(2); + } + + export function test() { + run(); + } + "#) + .test(); +} + +#[test] +fn cannot_reuse() { + test_support::project() + .file("src/lib.rs", r#" + #![feature(proc_macro, wasm_custom_section, wasm_import_module)] + + extern crate wasm_bindgen; + + use std::cell::Cell; + use wasm_bindgen::prelude::*; + + #[wasm_bindgen(module = "./test")] + extern { + fn call(a: &Fn()); + #[wasm_bindgen(catch)] + fn call_again() -> Result<(), JsValue>; + } + + #[wasm_bindgen] + pub fn run() { + call(&|| {}); + assert!(call_again().is_err()); + } + "#) + .file("test.ts", r#" + import { run } from "./out"; + + let CACHE: any = null; + + export function call(a: any) { + CACHE = a; + } + + export function call_again() { + CACHE(); + } + + export function test() { + run(); + } + "#) + .test(); +} +