Support mutable globals. Fixes #16

This commit is contained in:
Chad Retz 2018-07-20 14:43:54 -05:00
parent 9d87ce440f
commit 73e6b5769a
7 changed files with 117 additions and 45 deletions

View File

@ -13,4 +13,5 @@ public @interface WasmImport {
WasmExternalKind kind(); WasmExternalKind kind();
int resizableLimitInitial() default -1; int resizableLimitInitial() default -1;
int resizableLimitMaximum() default -1; int resizableLimitMaximum() default -1;
boolean globalSetter() default false;
} }

View File

@ -47,10 +47,13 @@ open class AstToAsm {
}) })
// Now all import globals as getter (and maybe setter) method handles // Now all import globals as getter (and maybe setter) method handles
ctx.cls.fields.addAll(ctx.importGlobals.mapIndexed { index, import -> ctx.cls.fields.addAll(ctx.importGlobals.mapIndexed { index, import ->
if ((import.kind as Node.Import.Kind.Global).type.mutable) throw CompileErr.MutableGlobalImport(index) val getter = FieldNode(Opcodes.ACC_PRIVATE + Opcodes.ACC_FINAL, ctx.importGlobalGetterFieldName(index),
FieldNode(Opcodes.ACC_PRIVATE + Opcodes.ACC_FINAL, ctx.importGlobalGetterFieldName(index),
MethodHandle::class.ref.asmDesc, null, null) MethodHandle::class.ref.asmDesc, null, null)
}) if (!(import.kind as Node.Import.Kind.Global).type.mutable) listOf(getter)
else listOf(getter, FieldNode(
Opcodes.ACC_PRIVATE + Opcodes.ACC_FINAL, ctx.importGlobalSetterFieldName(index),
MethodHandle::class.ref.asmDesc, null, null))
}.flatten())
// Now all non-import globals // Now all non-import globals
ctx.cls.fields.addAll(ctx.mod.globals.mapIndexed { index, global -> ctx.cls.fields.addAll(ctx.mod.globals.mapIndexed { index, global ->
val access = Opcodes.ACC_PRIVATE + if (!global.type.mutable) Opcodes.ACC_FINAL else 0 val access = Opcodes.ACC_PRIVATE + if (!global.type.mutable) Opcodes.ACC_FINAL else 0
@ -180,9 +183,11 @@ open class AstToAsm {
fun constructorImportTypes(ctx: ClsContext) = fun constructorImportTypes(ctx: ClsContext) =
ctx.importFuncs.map { MethodHandle::class.ref } + ctx.importFuncs.map { MethodHandle::class.ref } +
// We know it's only getters ctx.importGlobals.flatMap {
ctx.importGlobals.map { MethodHandle::class.ref } + // If it's mutable, it also comes with a setter
ctx.mod.imports.filter { it.kind is Node.Import.Kind.Table }.map { Array<MethodHandle>::class.ref } if ((it.kind as? Node.Import.Kind.Global)?.type?.mutable == false) listOf(MethodHandle::class.ref)
else listOf(MethodHandle::class.ref, MethodHandle::class.ref)
} + ctx.mod.imports.filter { it.kind is Node.Import.Kind.Table }.map { Array<MethodHandle>::class.ref }
fun toConstructorNode(ctx: ClsContext, func: Func) = mutableListOf<List<AnnotationNode>>().let { paramAnns -> fun toConstructorNode(ctx: ClsContext, func: Func) = mutableListOf<List<AnnotationNode>>().let { paramAnns ->
// If the first param is a mem class and imported, add annotation // If the first param is a mem class and imported, add annotation
@ -199,7 +204,15 @@ open class AstToAsm {
} }
// All non-mem imports one after another // All non-mem imports one after another
ctx.importFuncs.forEach { paramAnns.add(listOf(importAnnotation(ctx, it))) } ctx.importFuncs.forEach { paramAnns.add(listOf(importAnnotation(ctx, it))) }
ctx.importGlobals.forEach { paramAnns.add(listOf(importAnnotation(ctx, it))) } ctx.importGlobals.forEach {
paramAnns.add(listOf(importAnnotation(ctx, it)))
// There are two annotations here if it's mutable
if ((it.kind as? Node.Import.Kind.Global)?.type?.mutable == true)
paramAnns.add(listOf(importAnnotation(ctx, it).also {
it.values.add("globalSetter")
it.values.add(true)
}))
}
ctx.mod.imports.forEach { ctx.mod.imports.forEach {
if (it.kind is Node.Import.Kind.Table) paramAnns.add(listOf(importAnnotation(ctx, it))) if (it.kind is Node.Import.Kind.Table) paramAnns.add(listOf(importAnnotation(ctx, it)))
} }
@ -240,14 +253,25 @@ open class AstToAsm {
} }
fun setConstructorGlobalImports(ctx: ClsContext, func: Func, paramsBeforeImports: Int) = fun setConstructorGlobalImports(ctx: ClsContext, func: Func, paramsBeforeImports: Int) =
ctx.importGlobals.indices.fold(func) { func, importIndex -> ctx.importGlobals.foldIndexed(func to ctx.importFuncs.size + paramsBeforeImports) {
importIndex, (func, importParamOffset), import ->
// Always a getter handle
func.addInsns( func.addInsns(
VarInsnNode(Opcodes.ALOAD, 0), VarInsnNode(Opcodes.ALOAD, 0),
VarInsnNode(Opcodes.ALOAD, ctx.importFuncs.size + importIndex + paramsBeforeImports + 1), VarInsnNode(Opcodes.ALOAD, importParamOffset + 1),
FieldInsnNode(Opcodes.PUTFIELD, ctx.thisRef.asmName, FieldInsnNode(Opcodes.PUTFIELD, ctx.thisRef.asmName,
ctx.importGlobalGetterFieldName(importIndex), MethodHandle::class.ref.asmDesc) ctx.importGlobalGetterFieldName(importIndex), MethodHandle::class.ref.asmDesc)
) ).let { func ->
} // If it's mutable, it has a second setter handle
if ((import.kind as? Node.Import.Kind.Global)?.type?.mutable == false) func to importParamOffset + 1
else func.addInsns(
VarInsnNode(Opcodes.ALOAD, 0),
VarInsnNode(Opcodes.ALOAD, importParamOffset + 2),
FieldInsnNode(Opcodes.PUTFIELD, ctx.thisRef.asmName,
ctx.importGlobalSetterFieldName(importIndex), MethodHandle::class.ref.asmDesc)
) to importParamOffset + 2
}
}.first
fun setConstructorFunctionImports(ctx: ClsContext, func: Func, paramsBeforeImports: Int) = fun setConstructorFunctionImports(ctx: ClsContext, func: Func, paramsBeforeImports: Int) =
ctx.importFuncs.indices.fold(func) { func, importIndex -> ctx.importFuncs.indices.fold(func) { func, importIndex ->
@ -261,7 +285,10 @@ open class AstToAsm {
fun setConstructorTableImports(ctx: ClsContext, func: Func, paramsBeforeImports: Int) = fun setConstructorTableImports(ctx: ClsContext, func: Func, paramsBeforeImports: Int) =
if (ctx.mod.imports.none { it.kind is Node.Import.Kind.Table }) func else { if (ctx.mod.imports.none { it.kind is Node.Import.Kind.Table }) func else {
val importIndex = ctx.importFuncs.size + ctx.importGlobals.size + paramsBeforeImports + 1 val importIndex = ctx.importFuncs.size +
// Mutable global imports have setters and take up two spots
ctx.importGlobals.sumBy { if ((it.kind as? Node.Import.Kind.Global)?.type?.mutable == true) 2 else 1 } +
paramsBeforeImports + 1
func.addInsns( func.addInsns(
VarInsnNode(Opcodes.ALOAD, 0), VarInsnNode(Opcodes.ALOAD, 0),
VarInsnNode(Opcodes.ALOAD, importIndex), VarInsnNode(Opcodes.ALOAD, importIndex),
@ -299,11 +326,14 @@ open class AstToAsm {
global.type.contentType.typeRef, global.type.contentType.typeRef,
refGlobalKind.type.contentType.typeRef refGlobalKind.type.contentType.typeRef
) )
val paramOffset = ctx.importFuncs.size + paramsBeforeImports + 1 +
ctx.importGlobals.take(it.index).sumBy {
// Immutable jumps 1, mutable jumps 2
if ((it.kind as? Node.Import.Kind.Global)?.type?.mutable == false) 1
else 2
}
listOf( listOf(
VarInsnNode( VarInsnNode(Opcodes.ALOAD, paramOffset),
Opcodes.ALOAD,
ctx.importFuncs.size + it.index + paramsBeforeImports + 1
),
MethodInsnNode( MethodInsnNode(
Opcodes.INVOKEVIRTUAL, Opcodes.INVOKEVIRTUAL,
MethodHandle::class.ref.asmName, MethodHandle::class.ref.asmName,
@ -356,7 +386,10 @@ open class AstToAsm {
// Otherwise, it was imported and we can set the elems on the imported one // Otherwise, it was imported and we can set the elems on the imported one
// from the parameter // from the parameter
// TODO: I think this is a security concern and bad practice, may revisit // TODO: I think this is a security concern and bad practice, may revisit
val importIndex = ctx.importFuncs.size + ctx.importGlobals.size + paramsBeforeImports + 1 val importIndex = ctx.importFuncs.size + ctx.importGlobals.sumBy {
// Immutable is 1, mutable is 2
if ((it.kind as? Node.Import.Kind.Global)?.type?.mutable == false) 1 else 2
} + paramsBeforeImports + 1
return func.addInsns(VarInsnNode(Opcodes.ALOAD, importIndex)). return func.addInsns(VarInsnNode(Opcodes.ALOAD, importIndex)).
let { func -> addElemsToTable(ctx, func, paramsBeforeImports) }. let { func -> addElemsToTable(ctx, func, paramsBeforeImports) }.
// Remove the array that's still there // Remove the array that's still there
@ -532,28 +565,58 @@ open class AstToAsm {
is Either.Left -> (global.v.kind as Node.Import.Kind.Global).type is Either.Left -> (global.v.kind as Node.Import.Kind.Global).type
is Either.Right -> global.v.type is Either.Right -> global.v.type
} }
if (type.mutable) throw CompileErr.MutableGlobalExport(export.index)
// Create a simple getter // Create a simple getter
val method = MethodNode(Opcodes.ACC_PUBLIC, "get" + export.field.javaIdent.capitalize(), val getter = MethodNode(Opcodes.ACC_PUBLIC, "get" + export.field.javaIdent.capitalize(),
"()" + type.contentType.typeRef.asmDesc, null, null) "()" + type.contentType.typeRef.asmDesc, null, null)
method.addInsns(VarInsnNode(Opcodes.ALOAD, 0)) getter.addInsns(VarInsnNode(Opcodes.ALOAD, 0))
if (global is Either.Left) method.addInsns( if (global is Either.Left) getter.addInsns(
FieldInsnNode(Opcodes.GETFIELD, ctx.thisRef.asmName, FieldInsnNode(Opcodes.GETFIELD, ctx.thisRef.asmName,
ctx.importGlobalGetterFieldName(export.index), MethodHandle::class.ref.asmDesc), ctx.importGlobalGetterFieldName(export.index), MethodHandle::class.ref.asmDesc),
MethodInsnNode(Opcodes.INVOKEVIRTUAL, MethodHandle::class.ref.asmName, "invokeExact", MethodInsnNode(Opcodes.INVOKEVIRTUAL, MethodHandle::class.ref.asmName, "invokeExact",
"()" + type.contentType.typeRef.asmDesc, false) "()" + type.contentType.typeRef.asmDesc, false)
) else method.addInsns( ) else getter.addInsns(
FieldInsnNode(Opcodes.GETFIELD, ctx.thisRef.asmName, ctx.globalName(export.index), FieldInsnNode(Opcodes.GETFIELD, ctx.thisRef.asmName, ctx.globalName(export.index),
type.contentType.typeRef.asmDesc) type.contentType.typeRef.asmDesc)
) )
method.addInsns(InsnNode(when (type.contentType) { getter.addInsns(InsnNode(when (type.contentType) {
Node.Type.Value.I32 -> Opcodes.IRETURN Node.Type.Value.I32 -> Opcodes.IRETURN
Node.Type.Value.I64 -> Opcodes.LRETURN Node.Type.Value.I64 -> Opcodes.LRETURN
Node.Type.Value.F32 -> Opcodes.FRETURN Node.Type.Value.F32 -> Opcodes.FRETURN
Node.Type.Value.F64 -> Opcodes.DRETURN Node.Type.Value.F64 -> Opcodes.DRETURN
})) }))
method.visibleAnnotations = listOf(exportAnnotation(export)) getter.visibleAnnotations = listOf(exportAnnotation(export))
ctx.cls.methods.plusAssign(method) ctx.cls.methods.plusAssign(getter)
// If mutable, create simple setter
if (type.mutable) {
val setter = MethodNode(Opcodes.ACC_PUBLIC, "set" + export.field.javaIdent.capitalize(),
"(${type.contentType.typeRef.asmDesc})V", null, null)
setter.addInsns(VarInsnNode(Opcodes.ALOAD, 0))
if (global is Either.Left) setter.addInsns(
FieldInsnNode(Opcodes.GETFIELD, ctx.thisRef.asmName,
ctx.importGlobalSetterFieldName(export.index), MethodHandle::class.ref.asmDesc),
VarInsnNode(when (type.contentType) {
Node.Type.Value.I32 -> Opcodes.ILOAD
Node.Type.Value.I64 -> Opcodes.LLOAD
Node.Type.Value.F32 -> Opcodes.FLOAD
Node.Type.Value.F64 -> Opcodes.DLOAD
}, 1),
MethodInsnNode(Opcodes.INVOKEVIRTUAL, MethodHandle::class.ref.asmName, "invokeExact",
"(${type.contentType.typeRef.asmDesc})V", false),
InsnNode(Opcodes.RETURN)
) else setter.addInsns(
VarInsnNode(when (type.contentType) {
Node.Type.Value.I32 -> Opcodes.ILOAD
Node.Type.Value.I64 -> Opcodes.LLOAD
Node.Type.Value.F32 -> Opcodes.FLOAD
Node.Type.Value.F64 -> Opcodes.DLOAD
}, 1),
FieldInsnNode(Opcodes.PUTFIELD, ctx.thisRef.asmName, ctx.globalName(export.index),
type.contentType.typeRef.asmDesc),
InsnNode(Opcodes.RETURN)
)
setter.visibleAnnotations = listOf(exportAnnotation(export))
ctx.cls.methods.plusAssign(setter)
}
} }
fun addExportMemory(ctx: ClsContext, export: Node.Export) { fun addExportMemory(ctx: ClsContext, export: Node.Export) {

View File

@ -102,18 +102,6 @@ sealed class CompileErr(message: String, cause: Throwable? = null) : RuntimeExce
override val asmErrString get() = "global is immutable" override val asmErrString get() = "global is immutable"
} }
class MutableGlobalImport(
val index: Int
) : CompileErr("Attempted to import mutable global at index $index") {
override val asmErrString get() = "mutable globals cannot be imported"
}
class MutableGlobalExport(
val index: Int
) : CompileErr("Attempted to export global $index which is mutable") {
override val asmErrString get() = "mutable globals cannot be exported"
}
class GlobalInitNotConstant( class GlobalInitNotConstant(
val index: Int val index: Int
) : CompileErr("Expected init for global $index to be single constant value") { ) : CompileErr("Expected init for global $index to be single constant value") {

View File

@ -116,9 +116,9 @@ interface Module {
} }
// Global imports // Global imports
val globalImports = mod.imports.mapNotNull { val globalImports = mod.imports.flatMap {
if (it.kind is Node.Import.Kind.Global) ctx.resolveImportGlobal(it, it.kind.type) if (it.kind is Node.Import.Kind.Global) ctx.resolveImportGlobals(it, it.kind.type)
else null else emptyList()
} }
constructorParams += globalImports constructorParams += globalImports

View File

@ -55,4 +55,12 @@ sealed class RunErr(message: String, cause: Throwable? = null) : RuntimeExceptio
override val asmErrString get() = "unknown import" override val asmErrString get() = "unknown import"
override val asmErrStrings get() = listOf(asmErrString, "incompatible import type") override val asmErrStrings get() = listOf(asmErrString, "incompatible import type")
} }
class ImportGlobalInvalidMutability(
val module: String,
val field: String,
val expected: Boolean
) : RunErr("Expected imported global $module::$field to have mutability as ${!expected}") {
override val asmErrString get() = "incompatible import type"
}
} }

View File

@ -263,10 +263,12 @@ data class ScriptContext(
return Module.Compiled(mod, classLoader.fromBuiltContext(ctx), name, ctx.mem) return Module.Compiled(mod, classLoader.fromBuiltContext(ctx), name, ctx.mem)
} }
fun bindImport(import: Node.Import, getter: Boolean, methodType: MethodType): MethodHandle { fun bindImport(import: Node.Import, getter: Boolean, methodType: MethodType) = bindImport(
import, if (getter) "get" + import.field.javaIdent.capitalize() else import.field.javaIdent, methodType)
fun bindImport(import: Node.Import, javaName: String, methodType: MethodType): MethodHandle {
// Find a method that matches our expectations // Find a method that matches our expectations
val module = registrations[import.module] ?: throw RunErr.ImportNotFound(import.module, import.field) val module = registrations[import.module] ?: throw RunErr.ImportNotFound(import.module, import.field)
val javaName = if (getter) "get" + import.field.javaIdent.capitalize() else import.field.javaIdent
val kind = when (import.kind) { val kind = when (import.kind) {
is Node.Import.Kind.Func -> WasmExternalKind.FUNCTION is Node.Import.Kind.Func -> WasmExternalKind.FUNCTION
is Node.Import.Kind.Table -> WasmExternalKind.TABLE is Node.Import.Kind.Table -> WasmExternalKind.TABLE
@ -281,8 +283,18 @@ data class ScriptContext(
bindImport(import, false, bindImport(import, false,
MethodType.methodType(funcType.ret?.jclass ?: Void.TYPE, funcType.params.map { it.jclass })) MethodType.methodType(funcType.ret?.jclass ?: Void.TYPE, funcType.params.map { it.jclass }))
fun resolveImportGlobal(import: Node.Import, globalType: Node.Type.Global) = fun resolveImportGlobals(import: Node.Import, globalType: Node.Type.Global): List<MethodHandle> {
bindImport(import, true, MethodType.methodType(globalType.contentType.jclass)) val getter = bindImport(import, true, MethodType.methodType(globalType.contentType.jclass))
// Whether the setter is present or not defines whether it is mutable
val setter = try {
bindImport(import, "set" + import.field.javaIdent.capitalize(),
MethodType.methodType(Void.TYPE, globalType.contentType.jclass))
} catch (e: RunErr.ImportNotFound) { null }
// Mutability must match
if (globalType.mutable == (setter == null))
throw RunErr.ImportGlobalInvalidMutability(import.module, import.field, globalType.mutable)
return if (setter == null) listOf(getter) else listOf(getter, setter)
}
fun resolveImportMemory(import: Node.Import, memoryType: Node.Type.Memory, mem: Mem) = fun resolveImportMemory(import: Node.Import, memoryType: Node.Type.Memory, mem: Mem) =
bindImport(import, true, MethodType.methodType(Class.forName(mem.memType.asm.className))). bindImport(import, true, MethodType.methodType(Class.forName(mem.memType.asm.className))).

View File

@ -13,7 +13,7 @@ class SpecTestUnit(name: String, wast: String, expectedOutput: String?) : BaseTe
override val shouldFail get() = name.endsWith(".fail") override val shouldFail get() = name.endsWith(".fail")
override val defaultMaxMemPages get() = when (name) { override val defaultMaxMemPages get() = when (name) {
"nop"-> 20 "nop" -> 20
"resizing" -> 830 "resizing" -> 830
"imports" -> 5 "imports" -> 5
else -> 1 else -> 1