diff --git a/DESIGN.md b/DESIGN.md index 0db86779..aba2e513 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -536,31 +536,37 @@ available to JS through generated shims. If we take a look at the generated JS code for this we'll see: ```js -import * as wasm from './foo_bg'; +import * as wasm from './js_hello_world_bg'; export class Foo { - constructor(ptr) { - this.ptr = ptr; - } + static __construct(ptr) { + return new Foo(ptr); + } - free() { - const ptr = this.ptr; - this.ptr = 0; - wasm.__wbindgen_foo_free(ptr); - } + constructor(ptr) { + this.ptr = ptr; + } - static new(arg0) { - const ret = wasm.foo_new(arg0); - return new Foo(ret); - } + free() { + const ptr = this.ptr; + this.ptr = 0; + wasm.__wbg_foo_free(ptr); + } - get() { - return wasm.foo_get(this.ptr); - } + static new(arg0) { + const ret = wasm.foo_new(arg0); + return Foo.__construct(ret) + } - set(arg0) { - wasm.foo_set(this.ptr, arg0); - } + get() { + const ret = wasm.foo_get(this.ptr); + return ret; + } + + set(arg0) { + const ret = wasm.foo_set(this.ptr, arg0); + return ret; + } } ``` @@ -573,9 +579,7 @@ to JS: * Manual memory management is exposed in JS as well. The `free` function is required to be invoked to deallocate resources on the Rust side of things. -It's intended that `new Foo()` is never used in JS. When `wasm-bindgen` is run -with `--debug` it'll actually emit assertions to this effect to ensure that -instances of `Foo` are only constructed with the functions like `Foo.new` in JS. +To be able to use `new Foo()`, you'd need to annotate `new` as `#[wasm_bindgen(constructor)]`. One important aspect to note here, though, is that once `free` is called the JS object is "neutered" in that its internal pointer is nulled out. This means that diff --git a/crates/backend/src/ast.rs b/crates/backend/src/ast.rs index 397af170..cc3829c0 100644 --- a/crates/backend/src/ast.rs +++ b/crates/backend/src/ast.rs @@ -14,6 +14,7 @@ pub struct Export { pub class: Option, pub method: bool, pub mutable: bool, + pub constructor: Option, pub function: Function, } @@ -116,6 +117,7 @@ impl Program { class: None, method: false, mutable: false, + constructor: None, function: Function::from(f, opts), }); } @@ -126,8 +128,8 @@ impl Program { } syn::Item::Impl(mut i) => { let opts = opts.unwrap_or_else(|| BindgenAttrs::find(&mut i.attrs)); + self.push_impl(&mut i, opts); i.to_tokens(tokens); - self.push_impl(i, opts); } syn::Item::ForeignMod(mut f) => { let opts = opts.unwrap_or_else(|| BindgenAttrs::find(&mut f.attrs)); @@ -145,7 +147,7 @@ impl Program { } } - pub fn push_impl(&mut self, item: syn::ItemImpl, _opts: BindgenAttrs) { + pub fn push_impl(&mut self, item: &mut syn::ItemImpl, _opts: BindgenAttrs) { if item.defaultness.is_some() { panic!("default impls are not supported"); } @@ -168,16 +170,16 @@ impl Program { }, _ => panic!("unsupported self type in impl"), }; - for item in item.items.into_iter() { - self.push_impl_item(name, item); + for mut item in item.items.iter_mut() { + self.push_impl_item(name, &mut item); } } - fn push_impl_item(&mut self, class: syn::Ident, item: syn::ImplItem) { - let mut method = match item { + fn push_impl_item(&mut self, class: syn::Ident, item: &mut syn::ImplItem) { + let method = match item { syn::ImplItem::Const(_) => panic!("const definitions aren't supported"), syn::ImplItem::Type(_) => panic!("type definitions in impls aren't supported"), - syn::ImplItem::Method(m) => m, + syn::ImplItem::Method(ref mut m) => m, syn::ImplItem::Macro(_) => panic!("macros in impls aren't supported"), syn::ImplItem::Verbatim(_) => panic!("unparsed impl item?"), }; @@ -196,19 +198,27 @@ impl Program { } let opts = BindgenAttrs::find(&mut method.attrs); + let is_constructor = opts.constructor(); + let constructor = if is_constructor { + Some(method.sig.ident.to_string()) + } else { + None + }; let (function, mutable) = Function::from_decl( method.sig.ident, - Box::new(method.sig.decl), - method.attrs, + Box::new(method.sig.decl.clone()), + method.attrs.clone(), opts, - method.vis, + method.vis.clone(), true, ); + self.exports.push(Export { class: Some(class), method: mutable.is_some(), mutable: mutable.unwrap_or(false), + constructor, function, }); } @@ -536,6 +546,7 @@ impl Export { shared::Export { class: self.class.map(|s| s.as_ref().to_string()), method: self.method, + constructor: self.constructor.clone(), function: self.function.shared(), } } @@ -764,6 +775,7 @@ impl syn::synom::Synom for BindgenAttrs { )); } +#[derive(PartialEq)] enum BindgenAttr { Catch, Constructor, diff --git a/crates/cli-support/src/js.rs b/crates/cli-support/src/js.rs index 181988fe..9c85f38e 100644 --- a/crates/cli-support/src/js.rs +++ b/crates/cli-support/src/js.rs @@ -29,6 +29,7 @@ pub struct Context<'a> { pub struct ExportedClass { pub contents: String, pub typescript: String, + pub constructor: Option, } pub struct SubContext<'a, 'b: 'a> { @@ -316,42 +317,57 @@ impl<'a> Context<'a> { ts_dst.push_str(" public ptr: number; "); - if self.config.debug { - self.expose_check_token(); - dst.push_str(&format!(" - constructor(ptr, sym) {{ - _checkToken(sym); - this.ptr = ptr; - }} - ")); - ts_dst.push_str("constructor(ptr: number, sym: Symbol);\n"); - let new_name = shared::new_function(&class); - if self.wasm_import_needed(&new_name) { - self.expose_add_heap_object(); - self.export(&new_name, &format!(" - function(ptr) {{ - return addHeapObject(new {class}(ptr, token)); + if self.config.debug || exports.constructor.is_some() { + self.expose_constructor_token(); + + dst.push_str(&format!(" + static __construct(ptr) {{ + return new {}(new ConstructorToken(ptr)); + }} + + constructor(...args) {{ + if (args.length === 1 && args[0] instanceof ConstructorToken) {{ + this.ptr = args[0].ptr; + return; }} - ", class = class)); + ", class)); + + if let Some(constructor) = exports.constructor { + ts_dst.push_str(&format!("constructor(...args: [any]);\n")); + + dst.push_str(&format!(" + // This invocation of new will call this constructor with a ConstructorToken + let instance = {class}.{constructor}(...args); + this.ptr = instance.ptr; + ", class = class, constructor = constructor)); + } else { + dst.push_str("throw new Error('you cannot invoke `new` directly without having a \ + method annotated a constructor');"); } + + dst.push_str("}"); } else { dst.push_str(&format!(" + static __construct(ptr) {{ + return new {}(ptr); + }} + constructor(ptr) {{ this.ptr = ptr; }} - ")); - ts_dst.push_str("constructor(ptr: number);\n"); + ", class)); + } - let new_name = shared::new_function(&class); - if self.wasm_import_needed(&new_name) { - self.expose_add_heap_object(); - self.export(&new_name, &format!(" - function(ptr) {{ - return addHeapObject(new {class}(ptr)); - }} - ", class = class)); - } + let new_name = shared::new_function(&class); + if self.wasm_import_needed(&new_name) { + self.expose_add_heap_object(); + + self.export(&new_name, &format!(" + function(ptr) {{ + return addHeapObject({}.__construct(ptr)); + }} + ", class)); } dst.push_str(&format!(" @@ -589,19 +605,6 @@ impl<'a> Context<'a> { ", get_obj)); } - fn expose_check_token(&mut self) { - if !self.exposed_globals.insert("check_token") { - return; - } - self.globals.push_str(&format!(" - const token = Symbol('foo'); - function _checkToken(sym) {{ - if (token !== sym) - throw new Error('cannot invoke `new` directly'); - }} - ")); - } - fn expose_assert_num(&mut self) { if !self.exposed_globals.insert("assert_num") { return; @@ -765,6 +768,20 @@ impl<'a> Context<'a> { ")); } + fn expose_constructor_token(&mut self) { + if !self.exposed_globals.insert("ConstructorToken") { + return; + } + + self.globals.push_str(" + class ConstructorToken { + constructor(ptr) { + this.ptr = ptr; + } + } + "); + } + fn expose_get_string_from_wasm(&mut self) { if !self.exposed_globals.insert("get_string_from_wasm") { return; @@ -1236,11 +1253,8 @@ impl<'a> Context<'a> { if let Some(name) = ty.rust_struct() { dst_ts.push_str(": "); dst_ts.push_str(name); - return if self.config.debug { - format!("return new {name}(ret, token);", name = name) - } else { - format!("return new {name}(ret);", name = name) - } + + return format!("return {}.__construct(ret)",&name); } if ty.is_number() { @@ -1324,8 +1338,8 @@ impl<'a, 'b> SubContext<'a, 'b> { self.cx.typescript.push_str("\n"); } - pub fn generate_export_for_class(&mut self, class: &str, export: &shared::Export) { - let wasm_name = shared::struct_function_export_name(class, &export.function.name); + pub fn generate_export_for_class(&mut self, class_name: &str, export: &shared::Export) { + let wasm_name = shared::struct_function_export_name(class_name, &export.function.name); let descriptor = self.cx.describe(&wasm_name); let (js, ts) = self.generate_function( "", @@ -1334,12 +1348,26 @@ impl<'a, 'b> SubContext<'a, 'b> { export.method, &descriptor.unwrap_function(), ); - let class = self.cx.exported_classes.entry(class.to_string()) + + let class = self.cx.exported_classes.entry(class_name.to_string()) .or_insert(ExportedClass::default()); if !export.method { class.contents.push_str("static "); class.typescript.push_str("static "); } + + let constructors: Vec = self.program.exports + .iter() + .filter(|x| x.class == Some(class_name.to_string())) + .filter_map(|x| x.constructor.clone()) + .collect(); + + class.constructor = match constructors.len() { + 0 => None, + 1 => Some(constructors[0].clone()), + x @ _ => panic!("There must be only one constructor, not {}", x), + }; + class.contents.push_str(&export.function.name); class.contents.push_str(&js); class.contents.push_str("\n"); @@ -1560,15 +1588,11 @@ impl<'a, 'b> SubContext<'a, 'b> { continue } - if let Some(s) = arg.rust_struct() { + if let Some(class) = arg.rust_struct() { if arg.is_by_ref() { panic!("cannot invoke JS functions with custom ref types yet") } - let assign = if self.cx.config.debug { - format!("let c{0} = new {class}(arg{0}, token);", i, class = s) - } else { - format!("let c{0} = new {class}(arg{0});", i, class = s) - }; + let assign = format!("let c{0} = {1}.__construct(arg{0});", i, class); extra.push_str(&assign); invoc_args.push(format!("c{}", i)); continue diff --git a/crates/shared/src/lib.rs b/crates/shared/src/lib.rs index 3d73ad63..b46f806b 100644 --- a/crates/shared/src/lib.rs +++ b/crates/shared/src/lib.rs @@ -54,6 +54,7 @@ pub struct ImportType { pub struct Export { pub class: Option, pub method: bool, + pub constructor: Option, pub function: Function, } diff --git a/tests/all/classes.rs b/tests/all/classes.rs index 396dd575..ad3b181c 100644 --- a/tests/all/classes.rs +++ b/tests/all/classes.rs @@ -17,6 +17,7 @@ fn simple() { #[wasm_bindgen] impl Foo { + #[wasm_bindgen(constructor)] pub fn new() -> Foo { Foo::with_contents(0) } @@ -47,6 +48,10 @@ fn simple() { assert.strictEqual(r2.add(2), 13); assert.strictEqual(r2.add(3), 16); r2.free(); + + const r3 = new Foo(); + assert.strictEqual(r3.add(42), 42); + r3.free(); } "#) .test(); @@ -361,3 +366,82 @@ fn pass_into_js_as_js_class() { "#) .test(); } + + + +#[test] +fn constructors() { + project() + .file("src/lib.rs", r#" + #![feature(proc_macro, wasm_custom_section, wasm_import_module)] + + extern crate wasm_bindgen; + + use wasm_bindgen::prelude::*; + + #[wasm_bindgen] + pub fn cross_item_construction() -> Bar { + Bar::other_name(7, 8) + } + + #[wasm_bindgen] + pub struct Foo { + number: u32, + } + + #[wasm_bindgen] + impl Foo { + #[wasm_bindgen(constructor)] + pub fn new(number: u32) -> Foo { + Foo { number } + } + + pub fn get_number(&self) -> u32 { + self.number + } + } + + #[wasm_bindgen] + pub struct Bar { + number: u32, + number2: u32, + } + + #[wasm_bindgen] + impl Bar { + #[wasm_bindgen(constructor)] + pub fn other_name(number: u32, number2: u32) -> Bar { + Bar { number, number2 } + } + + pub fn get_sum(&self) -> u32 { + self.number + self.number2 + } + } + "#) + .file("test.ts", r#" + import * as assert from "assert"; + import { Foo, Bar, cross_item_construction } from "./out"; + + export function test() { + const foo = new Foo(1); + assert.strictEqual(foo.get_number(), 1); + foo.free(); + + const foo2 = Foo.new(2); + assert.strictEqual(foo2.get_number(), 2); + foo2.free(); + + const bar = new Bar(3, 4); + assert.strictEqual(bar.get_sum(), 7); + bar.free(); + + const bar2 = Bar.other_name(5, 6); + assert.strictEqual(bar2.get_sum(), 11); + bar2.free(); + + assert.strictEqual(cross_item_construction().get_sum(), 15); + } + "#) + .test(); +}