diff --git a/README.md b/README.md index 64037eaf..46a1b7d4 100644 --- a/README.md +++ b/README.md @@ -407,7 +407,7 @@ you can do are: ```rust #[wasm_bindgen] extern { - fn foo(a: &Fn()); // must be `Fn`, not `FnMut` + fn foo(a: &Fn()); // could also be `&mut FnMut()` } ``` @@ -452,7 +452,7 @@ returns, and the validity of the JS closure is tied to the lifetime of the `Closure` in Rust. Once `Closure` is dropped it will deallocate its internal memory and invalidate the corresponding JS function. -Unlike stack closures a `Closure` supports `FnMut`: +Like stack closures a `Closure` also supports `FnMut`: ```rust use wasm_bindgen::prelude::*; diff --git a/crates/backend/src/codegen.rs b/crates/backend/src/codegen.rs index 297c4b84..cfce2a6f 100644 --- a/crates/backend/src/codegen.rs +++ b/crates/backend/src/codegen.rs @@ -495,7 +495,14 @@ impl ToTokens for ast::ImportFunction { ::into_abi(#var, &mut __stack); }); } - ast::TypeKind::ByMutRef => panic!("urgh mut"), + ast::TypeKind::ByMutRef => { + abi_argument_names.push(name); + abi_arguments.push(quote! { #name: u32 }); + arg_conversions.push(quote! { + let #name = <#t as ::wasm_bindgen::convert::ToRefMutWasmBoundary> + ::to_abi_ref_mut(#name, &mut __stack); + }); + } ast::TypeKind::ByRef => { abi_argument_names.push(name); abi_arguments.push(quote! { #name: u32 }); diff --git a/crates/cli-support/src/descriptor.rs b/crates/cli-support/src/descriptor.rs index 5d590cc8..5cc4d7ca 100644 --- a/crates/cli-support/src/descriptor.rs +++ b/crates/cli-support/src/descriptor.rs @@ -212,13 +212,14 @@ impl Descriptor { } } - pub fn stack_closure(&self) -> Option<&Function> { - let inner = match *self { - Descriptor::Ref(ref d) => &**d, + pub fn stack_closure(&self) -> Option<(&Function, bool)> { + let (inner, mutable) = match *self { + Descriptor::Ref(ref d) => (&**d, false), + Descriptor::RefMut(ref d) => (&**d, true), _ => return None, }; match *inner { - Descriptor::Function(ref f) => Some(f), + Descriptor::Function(ref f) => Some((f, mutable)), _ => None, } } diff --git a/crates/cli-support/src/js.rs b/crates/cli-support/src/js.rs index 6b00b6b6..8352295a 100644 --- a/crates/cli-support/src/js.rs +++ b/crates/cli-support/src/js.rs @@ -1570,7 +1570,7 @@ impl<'a, 'b> SubContext<'a, 'b> { continue } - if let Some(f) = arg.stack_closure() { + if let Some((f, mutable)) = arg.stack_closure() { let args = (0..f.arguments.len()) .map(|i| format!("arg{}", i)) .collect::>() @@ -1578,14 +1578,25 @@ impl<'a, 'b> SubContext<'a, 'b> { self.cx.expose_get_global_argument(); self.cx.function_table_needed = true; let sep = if f.arguments.len() == 0 {""} else {","}; + let body = if mutable { + format!(" + let a = this.a; + this.a = 0; + try {{ + return this.f(a, this.b {} {}); + }} finally {{ + this.a = a; + }} + ", sep, args) + } else { + format!("return this.f(this.a, this.b {} {});", sep, args) + }; extra.push_str(&format!(" - let cb{0} = function({args}) {{ - return this.f(this.a, this.b {sep} {args}); - }}; + let cb{0} = function({args}) {{ {body} }}; cb{0}.f = wasm.__wbg_function_table.get(arg{0}); cb{0}.a = getGlobalArgument({next_global}); cb{0}.b = getGlobalArgument({next_global} + 1); - ", i, next_global = next_global, args = args, sep = sep)); + ", i, next_global = next_global, body = body, args = args)); next_global += 2; finally.push_str(&format!(" cb{0}.a = cb{0}.b = 0; diff --git a/src/convert.rs b/src/convert.rs index 6eae11dd..0f88f71f 100644 --- a/src/convert.rs +++ b/src/convert.rs @@ -42,6 +42,12 @@ pub trait ToRefWasmBoundary: WasmDescribe { fn to_abi_ref(&self, extra: &mut Stack) -> u32; } +pub trait ToRefMutWasmBoundary: WasmDescribe { + type Abi: WasmAbi; + + fn to_abi_ref_mut(&mut self, extra: &mut Stack) -> u32; +} + pub trait Stack { fn push(&mut self, bits: u32); fn pop(&mut self) -> u32; @@ -364,6 +370,61 @@ macro_rules! stack_closures { } } } + + impl<'a, $($var,)* R> ToRefMutWasmBoundary for FnMut($($var),*) -> R + 'a + where $($var: WasmAbi + WasmDescribe,)* + R: WasmAbi + WasmDescribe + { + type Abi = u32; + + fn to_abi_ref_mut(&mut self, extra: &mut Stack) -> u32 { + #[allow(non_snake_case)] + unsafe extern fn invoke<$($var,)* R>( + a: usize, + b: usize, + $($var: $var),* + ) -> R { + if a == 0 { + throw("closure invoked recursively or destroyed already"); + } + let f: &mut FnMut($($var),*) -> R = mem::transmute((a, b)); + f($($var),*) + } + unsafe { + let (a, b): (usize, usize) = mem::transmute(self); + extra.push(a as u32); + extra.push(b as u32); + invoke::<$($var,)* R> as u32 + } + } + } + + impl<'a, $($var,)*> ToRefMutWasmBoundary for FnMut($($var),*) + 'a + where $($var: WasmAbi + WasmDescribe,)* + { + type Abi = u32; + + fn to_abi_ref_mut(&mut self, extra: &mut Stack) -> u32 { + #[allow(non_snake_case)] + unsafe extern fn invoke<$($var,)* >( + a: usize, + b: usize, + $($var: $var),* + ) { + if a == 0 { + throw("closure invoked recursively or destroyed already"); + } + let f: &mut FnMut($($var),*) = mem::transmute((a, b)); + f($($var),*) + } + unsafe { + let (a, b): (usize, usize) = mem::transmute(self); + extra.push(a as u32); + extra.push(b as u32); + invoke::<$($var,)*> as u32 + } + } + } )*) } diff --git a/tests/all/closures.rs b/tests/all/closures.rs index 78dafa43..4abc5cbd 100644 --- a/tests/all/closures.rs +++ b/tests/all/closures.rs @@ -327,3 +327,109 @@ fn long_fnmut_recursive() { "#) .test(); } + +#[test] +fn fnmut() { + project() + .file("src/lib.rs", r#" + #![feature(proc_macro, wasm_custom_section, wasm_import_module)] + + extern crate wasm_bindgen; + + use std::cell::Cell; + use wasm_bindgen::prelude::*; + + #[wasm_bindgen(module = "./test")] + extern { + fn call(a: &mut FnMut()); + fn thread(a: &mut FnMut(u32) -> u32) -> u32; + } + + #[wasm_bindgen] + pub fn run() { + let mut a = false; + call(&mut || a = true); + assert!(a); + + let mut x = false; + assert_eq!(thread(&mut |a| { + x = true; + a + 1 + }), 3); + assert!(x); + } + "#) + .file("test.ts", r#" + import { run } from "./out"; + + export function call(a: any) { + a(); + } + + export function thread(a: any) { + return a(2); + } + + export function test() { + run(); + } + "#) + .test(); +} + +#[test] +fn fnmut_bad() { + project() + .file("src/lib.rs", r#" + #![feature(proc_macro, wasm_custom_section, wasm_import_module)] + + extern crate wasm_bindgen; + + use std::cell::Cell; + use wasm_bindgen::prelude::*; + + #[wasm_bindgen(module = "./test")] + extern { + fn call(a: &mut FnMut()); + #[wasm_bindgen(catch)] + fn again(a: bool) -> Result<(), JsValue>; + } + + #[wasm_bindgen] + pub fn run() { + let mut x = true; + let mut hits = 0; + call(&mut || { + hits += 1; + if again(hits == 1).is_err() { + return + } + x = false; + }); + assert!(hits == 1); + assert!(x); + + assert!(again(true).is_err()); + } + "#) + .file("test.ts", r#" + import { run } from "./out"; + + let F: any = null; + + export function call(a: any) { + F = a; + a(); + } + + export function again(x: boolean) { + if (x) F(); + } + + export function test() { + run(); + } + "#) + .test(); +} +