From 28c036129037df4affd7a5a1967ceb2a05e1ecff Mon Sep 17 00:00:00 2001 From: Sergey Pepyakin Date: Mon, 8 Jan 2018 15:10:50 +0300 Subject: [PATCH] Validate before instantiate --- examples/interpret.rs | 4 +++- examples/invoke.rs | 4 +++- examples/tictactoe.rs | 33 +++++++++++++++++++++++-------- src/interpreter/module.rs | 40 +++++++++++++++++++++----------------- src/interpreter/program.rs | 4 +++- src/lib.rs | 2 ++ src/validation/mod.rs | 27 +++++++++++++++++++++---- src/validation/tests.rs | 40 +++++++++++++++++++------------------- 8 files changed, 101 insertions(+), 53 deletions(-) diff --git a/examples/interpret.rs b/examples/interpret.rs index ae11d2b..ca334f6 100644 --- a/examples/interpret.rs +++ b/examples/interpret.rs @@ -17,12 +17,14 @@ fn main() { // `deserialize_file` function (which works only with modules) let module = parity_wasm::deserialize_file(&args[1]).expect("Failed to load module"); + let validated_module = parity_wasm::validate_module(module).expect("Failed to validate module"); + // Intialize deserialized module. It adds module into It expects 3 parameters: // - a name for the module // - a module declaration // - "main" module doesn't import native module(s) this is why we don't need to provide external native modules here // This test shows how to implement native module https://github.com/NikVolf/parity-wasm/blob/master/src/interpreter/tests/basics.rs#L197 - let main = ModuleInstance::new(&module) + let main = ModuleInstance::new(&validated_module) .run_start(&mut EmptyExternals) .expect("Failed to initialize module"); diff --git a/examples/invoke.rs b/examples/invoke.rs index 50326eb..276f526 100644 --- a/examples/invoke.rs +++ b/examples/invoke.rs @@ -69,12 +69,14 @@ fn main() { }).collect::>() }; + let validated_module = parity_wasm::validate_module(module).expect("Module to be valid"); + // Intialize deserialized module. It adds module into It expects 3 parameters: // - a name for the module // - a module declaration // - "main" module doesn't import native module(s) this is why we don't need to provide external native modules here // This test shows how to implement native module https://github.com/NikVolf/parity-wasm/blob/master/src/interpreter/tests/basics.rs#L197 - let main = ModuleInstance::new(&module) + let main = ModuleInstance::new(&validated_module) .run_start(&mut EmptyExternals) .expect("Failed to initialize module"); diff --git a/examples/tictactoe.rs b/examples/tictactoe.rs index b0de699..197d7a0 100644 --- a/examples/tictactoe.rs +++ b/examples/tictactoe.rs @@ -3,18 +3,22 @@ extern crate parity_wasm; use std::env; use std::fmt; use std::rc::Rc; -use parity_wasm::elements::{Module, FunctionType, ValueType, TableType, GlobalType, MemoryType}; +use parity_wasm::elements::{FunctionType, ValueType, TableType, GlobalType, MemoryType}; use parity_wasm::interpreter::{ Error as InterpreterError, ModuleInstance, UserError, HostFuncIndex, Externals, RuntimeValue, GlobalInstance, TableInstance, MemoryInstance, - TableRef, MemoryRef, GlobalRef, FuncRef, TryInto, ImportResolver, FuncInstance, + TableRef, MemoryRef, GlobalRef, FuncRef, TryInto, ImportResolver, FuncInstance }; +use parity_wasm::elements::{Error as DeserializationError}; +use parity_wasm::ValidationError; #[derive(Debug)] pub enum Error { OutOfRange, AlreadyOccupied, Interpreter(InterpreterError), + Deserialize(DeserializationError), + Validation(ValidationError), } impl fmt::Display for Error { @@ -29,6 +33,18 @@ impl From for Error { } } +impl From for Error { + fn from(e: DeserializationError) -> Error { + Error::Deserialize(e) + } +} + +impl From for Error { + fn from(e: ValidationError) -> Error { + Error::Validation(e) + } +} + impl UserError for Error {} mod tictactoe { @@ -243,9 +259,12 @@ impl<'a> ImportResolver for RuntimeImportResolver { } fn instantiate( - module: &Module, + path: &str, ) -> Result, Error> { - let instance = ModuleInstance::new(module) + let module = parity_wasm::deserialize_file(path)?; + let validated_module = parity_wasm::validate_module(module)?; + + let instance = ModuleInstance::new(&validated_module) .with_import("env", &RuntimeImportResolver) .assert_no_start()?; @@ -291,12 +310,10 @@ fn main() { println!("Usage: {} ", args[0]); return; } - let x_module = parity_wasm::deserialize_file(&args[1]).expect("X player module to load"); - let o_module = parity_wasm::deserialize_file(&args[2]).expect("Y player module to load"); // Instantiate modules of X and O players. - let x_instance = instantiate(&x_module).unwrap(); - let o_instance = instantiate(&o_module).unwrap(); + let x_instance = instantiate(&args[1]).expect("X player module to load"); + let o_instance = instantiate(&args[2]).expect("Y player module to load"); let result = play(x_instance, o_instance, &mut game); println!("result = {:?}, game = {:#?}", result, game); diff --git a/src/interpreter/module.rs b/src/interpreter/module.rs index b39890c..3687c33 100644 --- a/src/interpreter/module.rs +++ b/src/interpreter/module.rs @@ -3,7 +3,7 @@ use std::cell::RefCell; use std::fmt; use std::collections::HashMap; use std::borrow::Cow; -use elements::{External, FunctionType, GlobalType, InitExpr, Internal, MemoryType, Module, Opcode, +use elements::{External, FunctionType, GlobalType, InitExpr, Internal, MemoryType, Opcode, ResizableLimits, TableType, Type}; use interpreter::{Error, MemoryInstance, RuntimeValue, TableInstance}; use interpreter::imports::{ImportResolver, Imports}; @@ -12,7 +12,7 @@ use interpreter::func::{FuncRef, FuncBody, FuncInstance}; use interpreter::table::TableRef; use interpreter::memory::MemoryRef; use interpreter::host::Externals; -use validation::validate_module; +use validation::ValidatedModule; use common::{DEFAULT_MEMORY_INDEX, DEFAULT_TABLE_INDEX}; pub enum ExternVal { @@ -157,11 +157,12 @@ impl ModuleInstance { } fn alloc_module( - module: &Module, + validated_module: &ValidatedModule, extern_vals: &[ExternVal], instance: &Rc, ) -> Result<(), Error> { - let mut aux_data = validate_module(module)?; + let labels = validated_module.labels(); + let module = validated_module.module(); for &Type::Function(ref ty) in module.type_section().map(|ts| ts.types()).unwrap_or(&[]) { let type_id = alloc_func_type(ty.clone()); @@ -242,9 +243,9 @@ impl ModuleInstance { let func_type = instance.type_by_index(ty.type_ref()).expect( "Due to validation type should exists", ); - let labels = aux_data.labels.remove(&index).expect( + let labels = labels.get(&index).expect( "At func validation time labels are collected; Collected labels are added by index; qed", - ); + ).clone(); let func_body = FuncBody { locals: body.locals().to_vec(), opcodes: body.code().clone(), @@ -316,12 +317,13 @@ impl ModuleInstance { } fn instantiate_with_externvals( - module: &Module, + validated_module: &ValidatedModule, extern_vals: &[ExternVal], ) -> Result, Error> { + let module = validated_module.module(); let instance = Rc::new(ModuleInstance::default()); - ModuleInstance::alloc_module(module, extern_vals, &instance)?; + ModuleInstance::alloc_module(validated_module, extern_vals, &instance)?; for element_segment in module.elements_section().map(|es| es.entries()).unwrap_or( &[], @@ -360,9 +362,11 @@ impl ModuleInstance { } fn instantiate_with_imports( - module: &Module, + validated_module: &ValidatedModule, imports: &Imports, ) -> Result, Error> { + let module = validated_module.module(); + let mut extern_vals = Vec::new(); for import_entry in module.import_section().map(|s| s.entries()).unwrap_or(&[]) { let module_name = import_entry.module(); @@ -398,10 +402,10 @@ impl ModuleInstance { extern_vals.push(extern_val); } - Self::instantiate_with_externvals(module, &extern_vals) + Self::instantiate_with_externvals(validated_module, &extern_vals) } - pub fn new<'a>(module: &'a Module) -> InstantiationBuilder<'a> { + pub fn new<'a>(module: &'a ValidatedModule) -> InstantiationBuilder<'a> { InstantiationBuilder::new(module) } @@ -446,14 +450,14 @@ impl ModuleInstance { } pub struct InstantiationBuilder<'a> { - module: &'a Module, + validated_module: &'a ValidatedModule, imports: Option>, } impl<'a> InstantiationBuilder<'a> { - fn new(module: &'a Module) -> Self { + fn new(validated_module: &'a ValidatedModule) -> Self { InstantiationBuilder { - module, + validated_module, imports: None, } } @@ -476,9 +480,9 @@ impl<'a> InstantiationBuilder<'a> { pub fn run_start<'b, E: Externals>(mut self, state: &'b mut E) -> Result, Error> { let imports = self.imports.get_or_insert_with(|| Imports::default()); - let instance = ModuleInstance::instantiate_with_imports(self.module, imports)?; + let instance = ModuleInstance::instantiate_with_imports(self.validated_module, imports)?; - if let Some(start_fn_idx) = self.module.start_section() { + if let Some(start_fn_idx) = self.validated_module.module().start_section() { let start_func = instance.func_by_index(start_fn_idx).expect( "Due to validation start function should exists", ); @@ -488,9 +492,9 @@ impl<'a> InstantiationBuilder<'a> { } pub fn assert_no_start(mut self) -> Result, Error> { - assert!(self.module.start_section().is_none()); + assert!(self.validated_module.module().start_section().is_none()); let imports = self.imports.get_or_insert_with(|| Imports::default()); - let instance = ModuleInstance::instantiate_with_imports(self.module, imports)?; + let instance = ModuleInstance::instantiate_with_imports(self.validated_module, imports)?; Ok(instance) } } diff --git a/src/interpreter/program.rs b/src/interpreter/program.rs index 08ebc91..28bfe58 100644 --- a/src/interpreter/program.rs +++ b/src/interpreter/program.rs @@ -8,6 +8,7 @@ use interpreter::func::{FuncInstance, FuncRef}; use interpreter::value::RuntimeValue; use interpreter::imports::{Imports, ImportResolver}; use interpreter::host::Externals; +use validation::validate_module; /// Program instance. Program is a set of instantiated modules. #[deprecated] @@ -40,7 +41,8 @@ impl ProgramInstance { for (module_name, import_resolver) in self.resolvers.iter() { imports.push_resolver(&**module_name, &**import_resolver); } - ModuleInstance::new(&module) + let validate_module = validate_module(module)?; + ModuleInstance::new(&validate_module) .with_imports(imports) .run_start(externals)? }; diff --git a/src/lib.rs b/src/lib.rs index 538738f..ea879e2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,6 +22,8 @@ pub use elements::{ peek_size, }; +pub use validation::{validate_module, ValidatedModule, Error as ValidationError}; + #[allow(deprecated)] pub use interpreter::{ ProgramInstance, diff --git a/src/validation/mod.rs b/src/validation/mod.rs index 93d68e9..78cf6b2 100644 --- a/src/validation/mod.rs +++ b/src/validation/mod.rs @@ -29,11 +29,27 @@ impl From for Error { } } -pub struct AuxiliaryData { - pub labels: HashMap>, +#[derive(Clone)] +pub struct ValidatedModule { + labels: HashMap>, + module: Module, } -pub fn validate_module(module: &Module) -> Result { +impl<'a> ValidatedModule { + pub fn module(&self) -> &Module { + &self.module + } + + pub fn into_module(self) -> Module { + self.module + } + + pub(crate) fn labels(&self) -> &HashMap> { + &self.labels + } +} + +pub fn validate_module(module: Module) -> Result { let mut context_builder = ModuleContextBuilder::new(); let mut imported_globals = Vec::new(); let mut labels = HashMap::new(); @@ -233,7 +249,10 @@ pub fn validate_module(module: &Module) -> Result { } } - Ok(AuxiliaryData { labels }) + Ok(ValidatedModule { + module, + labels + }) } impl ResizableLimits { diff --git a/src/validation/tests.rs b/src/validation/tests.rs index 3e1f255..baa711f 100644 --- a/src/validation/tests.rs +++ b/src/validation/tests.rs @@ -8,7 +8,7 @@ use elements::{ #[test] fn empty_is_valid() { let module = module().build(); - assert!(validate_module(&module).is_ok()); + assert!(validate_module(module).is_ok()); } #[test] @@ -30,7 +30,7 @@ fn limits() { .with_max(max) .build() .build(); - assert_eq!(validate_module(&m).is_ok(), is_valid); + assert_eq!(validate_module(m).is_ok(), is_valid); // imported table let m = module() @@ -42,7 +42,7 @@ fn limits() { ) ) .build(); - assert_eq!(validate_module(&m).is_ok(), is_valid); + assert_eq!(validate_module(m).is_ok(), is_valid); // defined memory let m = module() @@ -51,7 +51,7 @@ fn limits() { .with_max(max) .build() .build(); - assert_eq!(validate_module(&m).is_ok(), is_valid); + assert_eq!(validate_module(m).is_ok(), is_valid); // imported table let m = module() @@ -63,7 +63,7 @@ fn limits() { ) ) .build(); - assert_eq!(validate_module(&m).is_ok(), is_valid); + assert_eq!(validate_module(m).is_ok(), is_valid); } } @@ -79,7 +79,7 @@ fn global_init_const() { ) ) .build(); - assert!(validate_module(&m).is_ok()); + assert!(validate_module(m).is_ok()); // init expr type differs from declared global type let m = module() @@ -90,7 +90,7 @@ fn global_init_const() { ) ) .build(); - assert!(validate_module(&m).is_err()); + assert!(validate_module(m).is_err()); } #[test] @@ -110,7 +110,7 @@ fn global_init_global() { ) ) .build(); - assert!(validate_module(&m).is_ok()); + assert!(validate_module(m).is_ok()); // get_global can reference only previously defined globals let m = module() @@ -121,7 +121,7 @@ fn global_init_global() { ) ) .build(); - assert!(validate_module(&m).is_err()); + assert!(validate_module(m).is_err()); // get_global can reference only const globals let m = module() @@ -139,7 +139,7 @@ fn global_init_global() { ) ) .build(); - assert!(validate_module(&m).is_err()); + assert!(validate_module(m).is_err()); // get_global in init_expr can only refer to imported globals. let m = module() @@ -156,7 +156,7 @@ fn global_init_global() { ) ) .build(); - assert!(validate_module(&m).is_err()); + assert!(validate_module(m).is_err()); } #[test] @@ -170,7 +170,7 @@ fn global_init_misc() { ) ) .build(); - assert!(validate_module(&m).is_err()); + assert!(validate_module(m).is_err()); // empty init expr let m = module() @@ -181,7 +181,7 @@ fn global_init_misc() { ) ) .build(); - assert!(validate_module(&m).is_err()); + assert!(validate_module(m).is_err()); // not an constant opcode used let m = module() @@ -192,7 +192,7 @@ fn global_init_misc() { ) ) .build(); - assert!(validate_module(&m).is_err()); + assert!(validate_module(m).is_err()); } #[test] @@ -210,7 +210,7 @@ fn module_limits_validity() { .with_min(10) .build() .build(); - assert!(validate_module(&m).is_err()); + assert!(validate_module(m).is_err()); // module cannot contain more than 1 table atm. let m = module() @@ -225,7 +225,7 @@ fn module_limits_validity() { .with_min(10) .build() .build(); - assert!(validate_module(&m).is_err()); + assert!(validate_module(m).is_err()); } #[test] @@ -247,7 +247,7 @@ fn funcs() { ])).build() .build() .build(); - assert!(validate_module(&m).is_ok()); + assert!(validate_module(m).is_ok()); } #[test] @@ -262,7 +262,7 @@ fn globals() { ) ) .build(); - assert!(validate_module(&m).is_ok()); + assert!(validate_module(m).is_ok()); // import mutable global is invalid. let m = module() @@ -274,7 +274,7 @@ fn globals() { ) ) .build(); - assert!(validate_module(&m).is_err()); + assert!(validate_module(m).is_err()); } #[test] @@ -297,5 +297,5 @@ fn if_else_with_return_type_validation() { ])).build() .build() .build(); - validate_module(&m).unwrap(); + validate_module(m).unwrap(); }