diff --git a/crates/backend/src/ast.rs b/crates/backend/src/ast.rs index b29f140d..69a8a1f1 100644 --- a/crates/backend/src/ast.rs +++ b/crates/backend/src/ast.rs @@ -14,6 +14,7 @@ pub struct Export { pub class: Option, pub method: bool, pub mutable: bool, + pub constructor: bool, pub function: Function, } @@ -122,6 +123,7 @@ impl Program { class: None, method: false, mutable: false, + constructor: false, function: Function::from(f, opts), }); } @@ -132,8 +134,8 @@ impl Program { } syn::Item::Impl(mut i) => { let opts = opts.unwrap_or_else(|| BindgenAttrs::find(&mut i.attrs)); + self.push_impl(&mut i, opts); i.to_tokens(tokens); - self.push_impl(i, opts); } syn::Item::ForeignMod(mut f) => { let opts = opts.unwrap_or_else(|| BindgenAttrs::find(&mut f.attrs)); @@ -151,7 +153,7 @@ impl Program { } } - pub fn push_impl(&mut self, item: syn::ItemImpl, _opts: BindgenAttrs) { + pub fn push_impl(&mut self, item: &mut syn::ItemImpl, _opts: BindgenAttrs) { if item.defaultness.is_some() { panic!("default impls are not supported"); } @@ -174,16 +176,16 @@ impl Program { }, _ => panic!("unsupported self type in impl"), }; - for item in item.items.into_iter() { - self.push_impl_item(name, item); + for mut item in item.items.iter_mut() { + self.push_impl_item(name, &mut item); } } - fn push_impl_item(&mut self, class: syn::Ident, item: syn::ImplItem) { - let mut method = match item { + fn push_impl_item(&mut self, class: syn::Ident, item: &mut syn::ImplItem) { + let method = match item { syn::ImplItem::Const(_) => panic!("const definitions aren't supported"), syn::ImplItem::Type(_) => panic!("type definitions in impls aren't supported"), - syn::ImplItem::Method(m) => m, + syn::ImplItem::Method(ref mut m) => m, syn::ImplItem::Macro(_) => panic!("macros in impls aren't supported"), syn::ImplItem::Verbatim(_) => panic!("unparsed impl item?"), }; @@ -201,21 +203,37 @@ impl Program { panic!("can only bindgen safe functions"); } - let opts = BindgenAttrs::find(&mut method.attrs); + let mut opts = BindgenAttrs::find(&mut method.attrs); + let constructor = opts.constructor(); + if constructor { + if method.sig.ident != syn::Ident::from("new") { + panic!("The constructor must be called 'new' for now") + } + + let pos = opts.attrs + .iter() + .enumerate() + .find(|(_, a)| **a == BindgenAttr::Constructor) + .unwrap() + .0; + opts.attrs.remove(pos); + } let (function, mutable) = Function::from_decl( method.sig.ident, - Box::new(method.sig.decl), - method.attrs, + Box::new(method.sig.decl.clone()), + method.attrs.clone(), opts, - method.vis, + method.vis.clone(), true, false, ); + self.exports.push(Export { class: Some(class), method: mutable.is_some(), mutable: mutable.unwrap_or(false), + constructor, function, }); } @@ -548,6 +566,7 @@ impl Export { shared::Export { class: self.class.map(|s| s.as_ref().to_string()), method: self.method, + constructor: self.constructor, function: self.function.shared(), } } @@ -792,6 +811,7 @@ impl syn::synom::Synom for BindgenAttrs { )); } +#[derive(PartialEq)] enum BindgenAttr { Catch, Constructor, diff --git a/crates/cli-support/src/js.rs b/crates/cli-support/src/js.rs index 6b00b6b6..75f9c457 100644 --- a/crates/cli-support/src/js.rs +++ b/crates/cli-support/src/js.rs @@ -29,6 +29,7 @@ pub struct Context<'a> { pub struct ExportedClass { pub contents: String, pub typescript: String, + pub constructor: bool, } pub struct SubContext<'a, 'b: 'a> { @@ -316,42 +317,42 @@ impl<'a> Context<'a> { ts_dst.push_str(" public ptr: number; "); - if self.config.debug { - self.expose_check_token(); - dst.push_str(&format!(" - constructor(ptr, sym) {{ - _checkToken(sym); - this.ptr = ptr; + + self.expose_constructor_token(); + + dst.push_str(&format!(" + constructor(...args) {{ + if (args.length === 1 && args[0] instanceof ConstructorToken) {{ + this.ptr = args[0].ptr; + return; }} ")); - ts_dst.push_str("constructor(ptr: number, sym: Symbol);\n"); - let new_name = shared::new_function(&class); - if self.wasm_import_needed(&new_name) { - self.expose_add_heap_object(); - self.export(&new_name, &format!(" - function(ptr) {{ - return addHeapObject(new {class}(ptr, token)); - }} - ", class = class)); - } + if exports.constructor { + ts_dst.push_str(&format!("constructor(...args: [any] | [ConstructorToken]);\n")); + + dst.push_str(&format!(" + // This invocation of new will call this constructor with a ConstructorToken + let instance = {class}.new(...args); + this.ptr = instance.ptr; + ", class = class)); } else { - dst.push_str(&format!(" - constructor(ptr) {{ - this.ptr = ptr; - }} - ")); - ts_dst.push_str("constructor(ptr: number);\n"); + ts_dst.push_str(&format!("constructor(...args: [ConstructorToken]);\n")); - let new_name = shared::new_function(&class); - if self.wasm_import_needed(&new_name) { - self.expose_add_heap_object(); - self.export(&new_name, &format!(" - function(ptr) {{ - return addHeapObject(new {class}(ptr)); - }} - ", class = class)); - } + dst.push_str("throw new Error('you cannot invoke `new` directly without having a \ + method annotated a constructor');"); + } + + dst.push_str("}"); + + let new_name = shared::new_function(&class); + if self.wasm_import_needed(&new_name) { + self.expose_add_heap_object(); + self.export(&new_name, &format!(" + function(ptr) {{ + return addHeapObject(new {class}(new ConstructorToken(ptr))); + }} + ", class = class)); } dst.push_str(&format!(" @@ -589,19 +590,6 @@ impl<'a> Context<'a> { ", get_obj)); } - fn expose_check_token(&mut self) { - if !self.exposed_globals.insert("check_token") { - return; - } - self.globals.push_str(&format!(" - const token = Symbol('foo'); - function _checkToken(sym) {{ - if (token !== sym) - throw new Error('cannot invoke `new` directly'); - }} - ")); - } - fn expose_assert_num(&mut self) { if !self.exposed_globals.insert("assert_num") { return; @@ -765,6 +753,27 @@ impl<'a> Context<'a> { ")); } + fn expose_constructor_token(&mut self) { + if !self.exposed_globals.insert("ConstructorToken") { + return; + } + + self.globals.push_str(" + class ConstructorToken { + constructor(ptr) { + this.ptr = ptr; + } + } + "); + + self.typescript.push_str(" + class ConstructorToken { + constructor(ptr: number); + } + "); + + } + fn expose_get_string_from_wasm(&mut self) { if !self.exposed_globals.insert("get_string_from_wasm") { return; @@ -1236,11 +1245,7 @@ impl<'a> Context<'a> { if let Some(name) = ty.rust_struct() { dst_ts.push_str(": "); dst_ts.push_str(name); - return if self.config.debug { - format!("return new {name}(ret, token);", name = name) - } else { - format!("return new {name}(ret);", name = name) - } + return format!("return new {name}(new ConstructorToken(ret));", name = name); } if ty.is_number() { @@ -1322,19 +1327,26 @@ impl<'a, 'b> SubContext<'a, 'b> { self.cx.typescript.push_str("\n"); } - pub fn generate_export_for_class(&mut self, class: &str, export: &shared::Export) { + pub fn generate_export_for_class(&mut self, class_name: &str, export: &shared::Export) { let (js, ts) = self.generate_function( "", - &shared::struct_function_export_name(class, &export.function.name), + &shared::struct_function_export_name(class_name, &export.function.name), export.method, &export.function, ); - let class = self.cx.exported_classes.entry(class.to_string()) + + let class = self.cx.exported_classes.entry(class_name.to_string()) .or_insert(ExportedClass::default()); if !export.method { class.contents.push_str("static "); class.typescript.push_str("static "); } + + class.constructor = self.program.exports + .iter() + .filter(|x| x.class == Some(class_name.to_string())) + .any(|x| x.constructor); + class.contents.push_str(&export.function.name); class.contents.push_str(&js); class.contents.push_str("\n"); @@ -1560,11 +1572,7 @@ impl<'a, 'b> SubContext<'a, 'b> { if arg.is_by_ref() { panic!("cannot invoke JS functions with custom ref types yet") } - let assign = if self.cx.config.debug { - format!("let c{0} = new {class}(arg{0}, token);", i, class = s) - } else { - format!("let c{0} = new {class}(arg{0});", i, class = s) - }; + let assign = format!("let c{0} = new {class}(new ConstructorToken(arg{0}));", i, class = s); extra.push_str(&assign); invoc_args.push(format!("c{}", i)); continue diff --git a/crates/shared/src/lib.rs b/crates/shared/src/lib.rs index 3d73ad63..6e73fd07 100644 --- a/crates/shared/src/lib.rs +++ b/crates/shared/src/lib.rs @@ -54,6 +54,7 @@ pub struct ImportType { pub struct Export { pub class: Option, pub method: bool, + pub constructor: bool, pub function: Function, } diff --git a/tests/all/classes.rs b/tests/all/classes.rs index 396dd575..a323f850 100644 --- a/tests/all/classes.rs +++ b/tests/all/classes.rs @@ -17,6 +17,7 @@ fn simple() { #[wasm_bindgen] impl Foo { + #[wasm_bindgen(constructor)] pub fn new() -> Foo { Foo::with_contents(0) } @@ -47,6 +48,10 @@ fn simple() { assert.strictEqual(r2.add(2), 13); assert.strictEqual(r2.add(3), 16); r2.free(); + + const r3 = new Foo(); + assert.strictEqual(r3.add(42), 42); + r3.free(); } "#) .test(); @@ -361,3 +366,75 @@ fn pass_into_js_as_js_class() { "#) .test(); } + + + +#[test] +fn constructors() { + project() + .file("src/lib.rs", r#" + #![feature(proc_macro, wasm_custom_section, wasm_import_module)] + + extern crate wasm_bindgen; + + use wasm_bindgen::prelude::*; + + #[wasm_bindgen] + pub struct Foo { + number: u32, + } + + #[wasm_bindgen] + impl Foo { + #[wasm_bindgen(constructor)] + pub fn new(number: u32) -> Foo { + Foo { number } + } + + pub fn get_number(&self) -> u32 { + self.number + } + } + + #[wasm_bindgen] + pub struct Bar { + number: u32, + number2: u32, + } + + #[wasm_bindgen] + impl Bar { + #[wasm_bindgen(constructor)] + pub fn new(number: u32, number2: u32) -> Bar { + Bar { number, number2 } + } + + pub fn get_sum(&self) -> u32 { + self.number + self.number2 + } + } + "#) + .file("test.ts", r#" + import * as assert from "assert"; + import { Foo, Bar } from "./out"; + + export function test() { + const foo = new Foo(1); + assert.strictEqual(foo.get_number(), 1); + foo.free(); + + const foo2 = Foo.new(2); + assert.strictEqual(foo2.get_number(), 2); + foo2.free(); + + const bar = new Bar(3, 4); + assert.strictEqual(bar.get_sum(), 7); + bar.free(); + + const bar2 = Bar.new(5, 6); + assert.strictEqual(bar2.get_sum(), 11); + bar2.free(); + } + "#) + .test(); +}