diff --git a/lib/llvm-backend/cpp/object_loader.hh b/lib/llvm-backend/cpp/object_loader.hh index d31d33266..9a6fe5681 100644 --- a/lib/llvm-backend/cpp/object_loader.hh +++ b/lib/llvm-backend/cpp/object_loader.hh @@ -1,6 +1,9 @@ #include #include #include +#include +#include +#include typedef enum { PROTECT_NONE, @@ -24,6 +27,8 @@ typedef uintptr_t (*lookup_vm_symbol_t)(const char* name_ptr, size_t length); typedef void (*fde_visitor_t)(uint8_t *fde); typedef result_t (*visit_fde_t)(uint8_t *fde, size_t size, fde_visitor_t visitor); +typedef void (*trampoline_t)(void*, void*, void*, void*); + typedef struct { /* Memory management. */ alloc_memory_t alloc_memory; @@ -35,6 +40,59 @@ typedef struct { visit_fde_t visit_fde; } callbacks_t; +class WasmException { +public: + virtual std::string description() const noexcept = 0; +}; + +struct UncatchableException : WasmException { +public: + enum class Type { + Unreachable, + IncorrectCallIndirectSignature, + Unknown, + } type; + + UncatchableException(Type type) : type(type) {} + + virtual std::string description() const noexcept override { + std::ostringstream ss; + ss + << "Uncatchable exception:" << '\n' + << " - type: " << type << '\n'; + + return ss.str(); + } + +private: + friend std::ostream& operator<<(std::ostream& out, const Type& ty) { + switch (ty) { + case Type::Unreachable: + out << "unreachable"; + break; + case Type::IncorrectCallIndirectSignature: + out << "incorrect call_indirect signature"; + break; + case Type::Unknown: + out << "unknown"; + break; + } + return out; + } +}; + +struct CatchableException : WasmException { +public: + CatchableException(uint32_t type_id, uint32_t value_num) : type_id(type_id), value_num(value_num) {} + + virtual std::string description() const noexcept override { + return "catchable exception"; + } + + uint32_t type_id, value_num; + uint64_t values[]; +}; + class WasmModule { public: WasmModule( @@ -57,10 +115,27 @@ extern "C" { return RESULT_OK; } + void throw_unreachable_exception() { + throw UncatchableException(UncatchableException::Type::Unreachable); + } + void throw_incorrect_call_indirect_signature() { + throw UncatchableException(UncatchableException::Type::IncorrectCallIndirectSignature); + } + void module_delete(WasmModule* module) { delete module; } + void invoke_trampoline(trampoline_t trampoline, void* ctx, void* func, void* params, void* results) { + try { + trampoline(ctx, func, params, results); + } catch(const WasmException& e) { + std::cout << e.description() << std::endl; + } catch (...) { + std::cout << "unknown exception" << std::endl; + } + } + void* get_func_symbol(WasmModule* module, const char* name) { return module->get_func(llvm::StringRef(name)); } diff --git a/lib/llvm-backend/src/backend.rs b/lib/llvm-backend/src/backend.rs index 66f29cbbb..ebffd914e 100644 --- a/lib/llvm-backend/src/backend.rs +++ b/lib/llvm-backend/src/backend.rs @@ -71,6 +71,16 @@ extern "C" { ) -> LLVMResult; fn module_delete(module: *mut LLVMModule); fn get_func_symbol(module: *mut LLVMModule, name: *const c_char) -> *const vm::Func; + fn throw_unreachable_exception(); + fn throw_incorrect_call_indirect_signature(); + // invoke_trampoline(trampoline_t trampoline, void* ctx, void* func, void* params, void* results) + fn invoke_trampoline( + trampoline: unsafe extern "C" fn(*mut vm::Ctx, *const vm::Func, *const u64, *mut u64), + vmctx_ptr: *mut vm::Ctx, + func_ptr: *const vm::Func, + params: *const u64, + results: *mut u64, + ); } fn get_callbacks() -> Callbacks { @@ -162,6 +172,11 @@ fn get_callbacks() -> Callbacks { fn_name!("vm.memory.size.dynamic.local") => vmcalls::local_dynamic_memory_size as _, fn_name!("vm.memory.grow.static.local") => vmcalls::local_static_memory_grow as _, fn_name!("vm.memory.size.static.local") => vmcalls::local_static_memory_size as _, + + fn_name!("vm.exception.throw.unreachable") => throw_unreachable_exception as _, + fn_name!("vm.exception.throw.incorrect-call_indirect_signature") => { + throw_incorrect_call_indirect_signature as _ + } _ => ptr::null(), } } @@ -343,7 +358,8 @@ impl ProtectedCaller for LLVMProtectedCaller { // Here we go. unsafe { - trampoline( + invoke_trampoline( + trampoline, vmctx_ptr, func_ptr, param_vec.as_ptr(), diff --git a/lib/llvm-backend/src/code.rs b/lib/llvm-backend/src/code.rs index a9170082b..00ad353c3 100644 --- a/lib/llvm-backend/src/code.rs +++ b/lib/llvm-backend/src/code.rs @@ -539,6 +539,9 @@ fn parse_function( // Emit an unreachable instruction. // If llvm cannot prove that this is never touched, // it will emit a `ud2` instruction on x86_64 arches. + + builder.build_call(intrinsics.throw_unreachable, &[], "throw"); + ctx.build_trap(); builder.build_unreachable(); @@ -731,13 +734,12 @@ fn parse_function( ) }; - // let sigindices_equal = builder.build_int_compare( - // IntPredicate::EQ, - // expected_dynamic_sigindex, - // found_dynamic_sigindex, - // "sigindices_equal", - // ); - let sigindices_equal = intrinsics.i1_ty.const_int(1, false); + let sigindices_equal = builder.build_int_compare( + IntPredicate::EQ, + expected_dynamic_sigindex, + found_dynamic_sigindex, + "sigindices_equal", + ); // Tell llvm that `expected_dynamic_sigindex` should equal `found_dynamic_sigindex`. let sigindices_equal = builder @@ -764,7 +766,11 @@ fn parse_function( ); builder.position_at_end(&sigindices_notequal_block); - ctx.build_trap(); + builder.build_call( + intrinsics.throw_incorrect_call_indirect_signature, + &[], + "throw", + ); builder.build_unreachable(); builder.position_at_end(&continue_block); diff --git a/lib/llvm-backend/src/intrinsics.rs b/lib/llvm-backend/src/intrinsics.rs index c3f2191f3..a6f7cb02e 100644 --- a/lib/llvm-backend/src/intrinsics.rs +++ b/lib/llvm-backend/src/intrinsics.rs @@ -105,6 +105,9 @@ pub struct Intrinsics { pub memory_size_static_import: FunctionValue, pub memory_size_shared_import: FunctionValue, + pub throw_unreachable: FunctionValue, + pub throw_incorrect_call_indirect_signature: FunctionValue, + ctx_ty: StructType, pub ctx_ptr_ty: PointerType, } @@ -344,7 +347,16 @@ impl Intrinsics { ret_i32_take_ctx_i32, None, ), - + throw_unreachable: module.add_function( + "vm.exception.throw.unreachable", + void_ty.fn_type(&[], false), + None, + ), + throw_incorrect_call_indirect_signature: module.add_function( + "vm.exception.throw.incorrect-call_indirect_signature", + void_ty.fn_type(&[], false), + None, + ), ctx_ty, ctx_ptr_ty, } diff --git a/lib/runtime/examples/call.rs b/lib/runtime/examples/call.rs index ac38b8d46..0684c1bb7 100644 --- a/lib/runtime/examples/call.rs +++ b/lib/runtime/examples/call.rs @@ -6,7 +6,7 @@ static WAT: &'static str = r#" (module (type (;0;) (func (param i32) (result i32))) (func (;0;) (type 0) (param i32) (result i32) - i32.const 42) + unreachable) (export "select_trap_l" (func 0)) ) "#;