More work on conditionals

This commit is contained in:
Chad Retz 2017-03-26 16:40:30 -05:00
parent 6014fd65f8
commit d4ca5885d3
13 changed files with 230 additions and 113 deletions

View File

@ -15,6 +15,9 @@ buildscript {
apply plugin: 'java'
apply plugin: 'kotlin'
apply plugin: 'application'
mainClassName = "asmble.cli.MainKt"
repositories {
mavenCentral()

View File

@ -9,7 +9,7 @@ data class Script(val commands: List<Cmd>) {
data class Get(val name: String?, val string: String): Action()
}
sealed class Assertion: Cmd() {
data class Return(val action: Action, val exprs: List<Node.Instr>): Assertion()
data class Return(val action: Action, val exprs: List<List<Node.Instr>>): Assertion()
data class ReturnNan(val action: Action): Assertion()
data class Trap(val action: Action, val failure: String): Assertion()
data class Malformed(val module: Node.Module, val failure: String): Assertion()

View File

@ -0,0 +1,5 @@
package asmble.cli
fun main(args: Array<String>) {
TODO("CLI")
}

View File

@ -38,6 +38,11 @@ fun <T : Exception> KClass<T>.athrow(msg: String) = listOf(
Void::class.ref.asMethodRetDesc(String::class.ref), false)
)
// Ug: https://youtrack.jetbrains.com/issue/KT-17064
fun KClass<*>.invokeStatic(name: String, retType: KClass<*>, vararg params: KClass<*>) =
MethodInsnNode(Opcodes.INVOKESTATIC, this.javaObjectType.ref.asmName, name,
retType.ref.asMethodRetDesc(*params.map { it.ref }.toTypedArray()), false)
val Class<*>.ref: TypeRef get() = TypeRef(Type.getType(this))
val Class<*>.valueType: Node.Type.Value? get() = when (this) {

View File

@ -74,19 +74,19 @@ data class Func(
}
}
fun pushBlock(insn: Node.Instr) = copy(blockStack = blockStack + Block.NoLabel(insn, insns.size))
fun pushBlock(insn: Node.Instr) = copy(blockStack = blockStack + Block(insn, insns.size, stack))
fun popBlock() = copy(blockStack = blockStack.dropLast(1)) to blockStack.last()
fun blockAtDepth(depth: Int) = blockStack[blockStack.size - depth].let { block ->
fun blockAtDepth(depth: Int) = blockStack[blockStack.size - depth - 1].let { block ->
when (block) {
is Block.WithLabel -> this to block
// We have to lazily create it here
is Block.NoLabel -> blockStack.toMutableList().let {
else -> blockStack.toMutableList().let {
val newBlock = block.withLabel(LabelNode())
it[blockStack.size - depth] = newBlock
it[blockStack.size - depth - 1] = newBlock
copy(blockStack = it) to newBlock
}
is Block.WithLabel -> this to block
}
}
@ -96,25 +96,23 @@ data class Func(
fun popIf() = copy(ifStack = ifStack.dropLast(1)) to peekIf()
sealed class Block {
abstract val insn: Node.Instr
abstract val startIndex: Int
abstract val maybeLabel: LabelNode?
open class Block(
val insn: Node.Instr,
val startIndex: Int,
val origStack: List<TypeRef>
) {
open val label: LabelNode? get() = null
open val blockExitVals: List<TypeRef?> = emptyList()
fun withLabel(label: LabelNode) = WithLabel(insn, startIndex, origStack, label)
val insnType: Node.Type.Value? get() = (insn as? Node.Instr.Args.Type)?.type
data class NoLabel(
override val insn: Node.Instr,
override val startIndex: Int
) : Block() {
override val maybeLabel: LabelNode? get() = null
fun withLabel(label: LabelNode) = WithLabel(insn, startIndex, label)
}
data class WithLabel(
override val insn: Node.Instr,
override val startIndex: Int,
val label: LabelNode
) : Block() {
override val maybeLabel: LabelNode? get() = label
class WithLabel(
insn: Node.Instr,
startIndex: Int,
origStack: List<TypeRef>,
override val label: LabelNode
) : Block(insn, startIndex, origStack) {
override var blockExitVals: List<TypeRef?> = emptyList()
}
}
}

View File

@ -100,13 +100,9 @@ open class FuncBuilder {
is Node.Instr.End ->
applyEnd(ctx, fn)
is Node.Instr.Br ->
fn.blockAtDepth(i.relativeDepth).let { (fn, block) ->
fn.addInsns(JumpInsnNode(Opcodes.GOTO, block.label))
}
applyBr(ctx, fn, i)
is Node.Instr.BrIf ->
fn.blockAtDepth(i.relativeDepth).let { (fn, block) ->
fn.addInsns(JumpInsnNode(Opcodes.IFNE, block.label))
}
applyBrIf(ctx, fn, i)
is Node.Instr.BrTable ->
applyBrTable(ctx, fn, i)
is Node.Instr.Return ->
@ -223,11 +219,11 @@ open class FuncBuilder {
applyF64Cmp(ctx, fn, Opcodes.IFGE)
is Node.Instr.I32Clz ->
// TODO Should make unsigned?
applyI32Unary(ctx, fn, Integer::numberOfLeadingZeros.invokeStatic())
applyI32Unary(ctx, fn, Integer::class.invokeStatic("numberOfLeadingZeros", Int::class, Int::class))
is Node.Instr.I32Ctz ->
applyI32Unary(ctx, fn, Integer::numberOfTrailingZeros.invokeStatic())
applyI32Unary(ctx, fn, Integer::class.invokeStatic("numberOfTrailingZeros", Int::class, Int::class))
is Node.Instr.I32Popcnt ->
applyI32Unary(ctx, fn, Integer::bitCount.invokeStatic())
applyI32Unary(ctx, fn, Integer::class.invokeStatic("bitCount", Int::class, Int::class))
is Node.Instr.I32Add ->
applyI32Binary(ctx, fn, Opcodes.IADD)
is Node.Instr.I32Sub ->
@ -237,11 +233,11 @@ open class FuncBuilder {
is Node.Instr.I32DivS ->
applyI32Binary(ctx, fn, Opcodes.IDIV)
is Node.Instr.I32DivU ->
applyI32Binary(ctx, fn, Integer::divideUnsigned.invokeStatic())
applyI32Binary(ctx, fn, Integer::class.invokeStatic("divideUnsigned", Int::class, Int::class, Int::class))
is Node.Instr.I32RemS ->
applyI32Binary(ctx, fn, Opcodes.IREM)
is Node.Instr.I32RemU ->
applyI32Binary(ctx, fn, Integer::remainderUnsigned.invokeStatic())
applyI32Binary(ctx, fn, Integer::class.invokeStatic("remainderUnsigned", Int::class, Int::class, Int::class))
is Node.Instr.I32And ->
applyI32Binary(ctx, fn, Opcodes.IAND)
is Node.Instr.I32Or ->
@ -255,15 +251,15 @@ open class FuncBuilder {
is Node.Instr.I32ShrU ->
applyI32Binary(ctx, fn, Opcodes.IUSHR)
is Node.Instr.I32Rotl ->
applyI32Binary(ctx, fn, Integer::rotateLeft.invokeStatic())
applyI32Binary(ctx, fn, Integer::class.invokeStatic("rotateLeft", Int::class, Int::class, Int::class))
is Node.Instr.I32Rotr ->
applyI32Binary(ctx, fn, Integer::rotateRight.invokeStatic())
applyI32Binary(ctx, fn, Integer::class.invokeStatic("rotateRight", Int::class, Int::class, Int::class))
is Node.Instr.I64Clz ->
applyI64Unary(ctx, fn, java.lang.Long::numberOfLeadingZeros.invokeStatic())
applyI64Unary(ctx, fn, java.lang.Long::class.invokeStatic("numberOfLeadingZeros", Int::class, Long::class))
is Node.Instr.I64Ctz ->
applyI64Unary(ctx, fn, java.lang.Long::numberOfTrailingZeros.invokeStatic())
applyI64Unary(ctx, fn, java.lang.Long::class.invokeStatic("numberOfTrailingZeros", Int::class, Long::class))
is Node.Instr.I64Popcnt ->
applyI64Unary(ctx, fn, java.lang.Long::bitCount.invokeStatic())
applyI64Unary(ctx, fn, java.lang.Long::class.invokeStatic("bitCount", Int::class, Long::class))
is Node.Instr.I64Add ->
applyI64Binary(ctx, fn, Opcodes.LADD)
is Node.Instr.I64Sub ->
@ -273,11 +269,13 @@ open class FuncBuilder {
is Node.Instr.I64DivS ->
applyI64Binary(ctx, fn, Opcodes.LDIV)
is Node.Instr.I64DivU ->
applyI64Binary(ctx, fn, java.lang.Long::divideUnsigned.invokeStatic())
applyI64Binary(ctx, fn, java.lang.Long::class.invokeStatic("divideUnsigned",
Long::class, Long::class, Long::class))
is Node.Instr.I64RemS ->
applyI64Binary(ctx, fn, Opcodes.LREM)
is Node.Instr.I64RemU ->
applyI64Binary(ctx, fn, java.lang.Long::remainderUnsigned.invokeStatic())
applyI64Binary(ctx, fn, java.lang.Long::class.invokeStatic("remainderUnsigned",
Long::class, Long::class, Long::class))
is Node.Instr.I64And ->
applyI64Binary(ctx, fn, Opcodes.LAND)
is Node.Instr.I64Or ->
@ -291,9 +289,11 @@ open class FuncBuilder {
is Node.Instr.I64ShrU ->
applyI64Binary(ctx, fn, Opcodes.LUSHR)
is Node.Instr.I64Rotl ->
applyI64Binary(ctx, fn, java.lang.Long::rotateLeft.invokeStatic())
applyI64Binary(ctx, fn, java.lang.Long::class.invokeStatic("rotateLeft",
Long::class, Long::class, Int::class))
is Node.Instr.I64Rotr ->
applyI64Binary(ctx, fn, java.lang.Long::rotateRight.invokeStatic())
applyI64Binary(ctx, fn, java.lang.Long::class.invokeStatic("rotateRight",
Long::class, Long::class, Int::class))
is Node.Instr.F32Abs ->
applyF32Unary(ctx, fn, forceFnType<(Float) -> Float>(Math::abs).invokeStatic())
is Node.Instr.F32Neg ->
@ -361,16 +361,17 @@ open class FuncBuilder {
is Node.Instr.I32TruncUF32 ->
// TODO: wat?
applyConv(ctx, fn, Float::class.ref, Int::class.ref, Opcodes.F2I).
addInsns(java.lang.Short::toUnsignedInt.invokeStatic())
addInsns(java.lang.Short::class.invokeStatic("toUnsignedInt", Int::class, Short::class))
is Node.Instr.I32TruncSF64 ->
applyConv(ctx, fn, Double::class.ref, Int::class.ref, Opcodes.D2I)
is Node.Instr.I32TruncUF64 ->
applyConv(ctx, fn, Double::class.ref, Int::class.ref, Opcodes.D2I).
addInsns(java.lang.Short::toUnsignedInt.invokeStatic())
addInsns(java.lang.Short::class.invokeStatic("toUnsignedInt", Int::class, Short::class))
is Node.Instr.I64ExtendSI32 ->
applyConv(ctx, fn, Int::class.ref, Long::class.ref, Opcodes.I2L)
is Node.Instr.I64ExtendUI32 ->
applyConv(ctx, fn, Int::class.ref, Long::class.ref, Integer::toUnsignedLong.invokeStatic())
applyConv(ctx, fn, Int::class.ref, Long::class.ref,
Integer::class.invokeStatic("toUnsignedLong", Long::class, Int::class))
is Node.Instr.I64TruncSF32 ->
applyConv(ctx, fn, Float::class.ref, Long::class.ref, Opcodes.F2L)
is Node.Instr.I64TruncUF32 ->
@ -385,7 +386,7 @@ open class FuncBuilder {
is Node.Instr.F32ConvertSI32 ->
applyConv(ctx, fn, Int::class.ref, Float::class.ref, Opcodes.I2F)
is Node.Instr.F32ConvertUI32 ->
fn.addInsns(Integer::toUnsignedLong.invokeStatic()).
fn.addInsns(Integer::class.invokeStatic("toUnsignedLong", Long::class, Int::class)).
let { applyConv(ctx, it, Int::class.ref, Float::class.ref, Opcodes.L2F) }
is Node.Instr.F32ConvertSI64 ->
applyConv(ctx, fn, Long::class.ref, Float::class.ref, Opcodes.L2F)
@ -397,7 +398,7 @@ open class FuncBuilder {
is Node.Instr.F64ConvertSI32 ->
applyConv(ctx, fn, Int::class.ref, Double::class.ref, Opcodes.I2D)
is Node.Instr.F64ConvertUI32 ->
fn.addInsns(Integer::toUnsignedLong.invokeStatic()).
fn.addInsns(Integer::class.invokeStatic("toUnsignedLong", Long::class, Int::class)).
let { applyConv(ctx, it, Int::class.ref, Double::class.ref, Opcodes.L2D) }
is Node.Instr.F64ConvertSI64 ->
applyConv(ctx, fn, Long::class.ref, Double::class.ref, Opcodes.L2D)
@ -407,23 +408,51 @@ open class FuncBuilder {
is Node.Instr.F64PromoteF32 ->
applyConv(ctx, fn, Float::class.ref, Double::class.ref, Opcodes.F2D)
is Node.Instr.I32ReinterpretF32 ->
applyConv(ctx, fn, Float::class.ref, Int::class.ref, java.lang.Float::floatToRawIntBits.invokeStatic())
applyConv(ctx, fn, Float::class.ref, Int::class.ref,
java.lang.Float::class.invokeStatic("floatToRawIntBits", Int::class, Float::class))
is Node.Instr.I64ReinterpretF64 ->
applyConv(ctx, fn, Double::class.ref, Long::class.ref, java.lang.Double::doubleToRawLongBits.invokeStatic())
applyConv(ctx, fn, Double::class.ref, Long::class.ref,
java.lang.Double::class.invokeStatic("doubleToRawLongBits", Long::class, Double::class))
is Node.Instr.F32ReinterpretI32 ->
applyConv(ctx, fn, Int::class.ref, Float::class.ref, java.lang.Float::intBitsToFloat.invokeStatic())
applyConv(ctx, fn, Int::class.ref, Float::class.ref,
java.lang.Float::class.invokeStatic("intBitsToFloat", Float::class, Int::class))
is Node.Instr.F64ReinterpretI64 ->
applyConv(ctx, fn, Double::class.ref, Long::class.ref, java.lang.Double::longBitsToDouble.invokeStatic())
applyConv(ctx, fn, Double::class.ref, Long::class.ref,
java.lang.Double::class.invokeStatic("longBitsToDouble", Double::class, Long::class))
}
fun applyBr(ctx: FuncContext, fn: Func, i: Node.Instr.Br) =
fn.blockAtDepth(i.relativeDepth).let { (fn, block) ->
fn.addInsns(JumpInsnNode(Opcodes.GOTO, block.label)).let { fn ->
block.insnType?.typeRef?.let { typ ->
fn.popExpecting(typ).also { block.blockExitVals += typ }
} ?: fn
}
}
fun applyBrIf(ctx: FuncContext, fn: Func, i: Node.Instr.BrIf) =
fn.blockAtDepth(i.relativeDepth).let { (fn, block) ->
fn.popExpecting(Int::class.ref).addInsns(JumpInsnNode(Opcodes.IFNE, block.label)).let { fn ->
// We don't have to pop this like we do with br, because it's conditional
block.insnType?.typeRef?.let { block.blockExitVals += it }
fn
}
}
// Can compile quite cleanly as a table switch on the JVM
fun applyBrTable(ctx: FuncContext, fn: Func, insn: Node.Instr.BrTable) =
fn.blockAtDepth(insn.default).let { (fn, defaultBlock) ->
defaultBlock.insnType?.typeRef?.let { defaultBlock.blockExitVals += it }
insn.targetTable.fold(fn to emptyList<LabelNode>()) { (fn, labels), targetDepth ->
fn.blockAtDepth(targetDepth).let { (fn, targetBlock) -> fn to (labels + targetBlock.label) }
fn.blockAtDepth(targetDepth).let { (fn, targetBlock) ->
targetBlock.insnType?.typeRef?.let { targetBlock.blockExitVals += it }
fn to (labels + targetBlock.label)
}
}.let { (fn, targetLabels) ->
fn.addInsns(TableSwitchInsnNode(0, targetLabels.size - 1,
fn.popExpecting(Int::class.ref).addInsns(TableSwitchInsnNode(0, targetLabels.size - 1,
defaultBlock.label, *targetLabels.toTypedArray()))
}.let { fn ->
defaultBlock.insnType?.typeRef?.let { fn.popExpecting(it) } ?: fn
}
}
@ -435,30 +464,52 @@ open class FuncBuilder {
}
fun applyEnd(ctx: FuncContext, fn: Func) = fn.popBlock().let { (fn, block) ->
when (block.insn) {
is Node.Instr.Block ->
// Add label to end of block if it's there
block.maybeLabel?.let { fn.addInsns(it) } ?: fn
is Node.Instr.Loop ->
// Add label to beginning of loop if it's there
block.maybeLabel?.let { fn.copy(insns = fn.insns.add(block.startIndex, it)) } ?: fn
is Node.Instr.If -> fn.popIf().let { (fn, jumpNode) ->
when (block.maybeLabel) {
// If there is no existing break label, add one to initial
// "if" only if it isn't there from an "else"
null -> if (jumpNode.label != null) fn else {
jumpNode.label = LabelNode()
fn.addInsns(jumpNode.label)
}
// If there is one, add it to the initial "if"
// if the "else" didn't set one on there...then push it
else -> {
if (jumpNode.label == null) jumpNode.label = block.maybeLabel
fn.addInsns(block.maybeLabel!!)
// Go over each exit and make sure it did the right thing
block.blockExitVals.forEach {
require(it == block.insnType?.typeRef) { "Block exit val was $it, expected ${block.insnType}" }
}
// We need to check the current stack
when (block.insnType) {
null -> {
require(fn.stack == block.origStack) {
"At block end, expected stack ${block.origStack}, got ${fn.stack}"
}
fn
}
else -> {
val typ = block.insnType!!.typeRef
require(fn.stack == block.origStack || fn.stack == block.origStack + typ) {
"At block end, expected stack ${block.origStack} and maybe $typ, got ${fn.stack}"
}
// We have to add the expected type ourselves, there wasn't a fall through...
if (fn.stack.size == block.origStack.size) fn.push(typ) else fn
}
}.let { fn ->
when (block.insn) {
is Node.Instr.Block ->
// Add label to end of block if it's there
block.label?.let { fn.addInsns(it) } ?: fn
is Node.Instr.Loop ->
// Add label to beginning of loop if it's there
block.label?.let { fn.copy(insns = fn.insns.add(block.startIndex, it)) } ?: fn
is Node.Instr.If -> fn.popIf().let { (fn, jumpNode) ->
when (block.label) {
// If there is no existing break label, add one to initial
// "if" only if it isn't there from an "else"
null -> if (jumpNode.label != null) fn else {
jumpNode.label = LabelNode()
fn.addInsns(jumpNode.label)
}
// If there is one, add it to the initial "if"
// if the "else" didn't set one on there...then push it
else -> {
if (jumpNode.label == null) jumpNode.label = block.label
fn.addInsns(block.label!!)
}
}
}
else -> error("Unrecognized end for ${block.insn}")
}
else -> error("Unrecognized end for ${block.insn}")
}
}
@ -571,7 +622,7 @@ open class FuncBuilder {
// TODO: test whether we need FCMPG instead
addInsns(InsnNode(Opcodes.FCMPL)).
push(Int::class.ref).
let { applyI32UnaryCmp(ctx, fn, op) }
let { fn -> applyI32UnaryCmp(ctx, fn, op) }
fun applyF64Cmp(ctx: FuncContext, fn: Func, op: Int) =
fn.popExpecting(Double::class.ref).
@ -579,13 +630,15 @@ open class FuncBuilder {
// TODO: test whether we need DCMPG instead
addInsns(InsnNode(Opcodes.DCMPL)).
push(Int::class.ref).
let { applyI32UnaryCmp(ctx, fn, op) }
let { fn -> applyI32UnaryCmp(ctx, fn, op) }
fun applyI64CmpU(ctx: FuncContext, fn: Func, op: Int) =
applyCmpU(ctx, fn, op, Long::class.ref, java.lang.Long::compareUnsigned.invokeStatic())
applyCmpU(ctx, fn, op, Long::class.ref,
java.lang.Long::class.invokeStatic("compareUnsigned", Int::class, Long::class, Long::class))
fun applyI32CmpU(ctx: FuncContext, fn: Func, op: Int) =
applyCmpU(ctx, fn, op, Int::class.ref, Integer::compareUnsigned.invokeStatic())
applyCmpU(ctx, fn, op, Int::class.ref,
Integer::class.invokeStatic("compareUnsigned", Int::class, Int::class, Int::class))
fun applyCmpU(ctx: FuncContext, fn: Func, op: Int, inTypes: TypeRef, meth: MethodInsnNode) =
// Call the method, then compare with 0
@ -593,14 +646,14 @@ open class FuncBuilder {
popExpecting(inTypes).
addInsns(meth).
push(Int::class.ref).
let { applyI32UnaryCmp(ctx, it, op) }
let { fn -> applyI32UnaryCmp(ctx, fn, op) }
fun applyI64CmpS(ctx: FuncContext, fn: Func, op: Int) =
fn.popExpecting(Long::class.ref).
popExpecting(Long::class.ref).
addInsns(InsnNode(Opcodes.LCMP)).
push(Int::class.ref).
let { applyI32UnaryCmp(ctx, fn, op) }
let { fn -> applyI32UnaryCmp(ctx, fn, op) }
fun applyI32CmpS(ctx: FuncContext, fn: Func, op: Int) = applyCmpS(ctx, fn, op, Int::class.ref)

View File

@ -27,13 +27,14 @@ open class InsnReworker {
var insnsToInject = emptyMap<Int, List<Insn>>()
fun injectBeforeLastStackCount(insn: Insn, count: Int) {
ctx.trace { "Injecting $insn back $count stack values" }
fun inject(index: Int) {
insnsToInject += index to (insnsToInject[index]?.let { listOf(insn) + it } ?: listOf(insn))
}
if (count == 0) return inject(stackManips.size)
var countSoFar = 0
for ((amountChanged, insnIndex) in stackManips.asReversed()) {
countSoFar += amountChanged
if (countSoFar == count) {
insnsToInject += insnIndex to (insnsToInject[insnIndex]?.let { listOf(insn) + it } ?: listOf(insn))
return
}
if (countSoFar == count) return inject(insnIndex)
}
error("Unable to find place to inject $insn")
}
@ -87,6 +88,10 @@ open class InsnReworker {
}
fun insnStackDiff(ctx: ClsContext, insn: Node.Instr) = when (insn) {
is Node.Instr.Unreachable, is Node.Instr.Nop, is Node.Instr.Block,
is Node.Instr.Loop, is Node.Instr.If, is Node.Instr.Else,
is Node.Instr.End, is Node.Instr.Br, is Node.Instr.BrIf,
is Node.Instr.BrTable, is Node.Instr.Return -> NOP
is Node.Instr.Call -> ctx.funcTypeAtIndex(insn.index).let {
// All calls pop "this" + params, and any return is a push
POP_THIS + (POP_PARAM + it.params.size) + (if (it.ret == null) NOP else PUSH_RESULT)
@ -151,7 +156,6 @@ open class InsnReworker {
is Node.Instr.F64ConvertSI64, is Node.Instr.F64ConvertUI64, is Node.Instr.F64PromoteF32,
is Node.Instr.I32ReinterpretF32, is Node.Instr.I64ReinterpretF64, is Node.Instr.F32ReinterpretI32,
is Node.Instr.F64ReinterpretI64 -> POP_PARAM + PUSH_RESULT
else -> TODO()
}
fun nonAdjacentMemAccesses(insns: List<Insn>) = insns.fold(0 to false) { (count, lastCouldHaveMem), insn ->

View File

@ -12,14 +12,22 @@ open class AstToSExpr {
}
fun fromAssertion(v: Script.Cmd.Assertion) = when(v) {
is Script.Cmd.Assertion.Return -> newMulti("assert_return") + fromAction(v.action) + fromInstrs(v.exprs)
is Script.Cmd.Assertion.ReturnNan -> newMulti("assert_return_nan") + fromAction(v.action)
is Script.Cmd.Assertion.Trap -> newMulti("assert_trap") + fromAction(v.action) + v.failure
is Script.Cmd.Assertion.Malformed -> newMulti("assert_malformed") + fromModule(v.module) + v.failure
is Script.Cmd.Assertion.Invalid -> newMulti("assert_invalid") + fromModule(v.module) + v.failure
is Script.Cmd.Assertion.SoftInvalid -> newMulti("assert_soft_invalid") + fromModule(v.module) + v.failure
is Script.Cmd.Assertion.Unlinkable -> newMulti("assert_unlinkable") + fromModule(v.module) + v.failure
is Script.Cmd.Assertion.TrapModule -> newMulti("assert_trap") + fromModule(v.module) + v.failure
is Script.Cmd.Assertion.Return ->
newMulti("assert_return") + fromAction(v.action) + v.exprs.flatMap(this::fromInstrs)
is Script.Cmd.Assertion.ReturnNan ->
newMulti("assert_return_nan") + fromAction(v.action)
is Script.Cmd.Assertion.Trap ->
newMulti("assert_trap") + fromAction(v.action) + v.failure
is Script.Cmd.Assertion.Malformed ->
newMulti("assert_malformed") + fromModule(v.module) + v.failure
is Script.Cmd.Assertion.Invalid ->
newMulti("assert_invalid") + fromModule(v.module) + v.failure
is Script.Cmd.Assertion.SoftInvalid ->
newMulti("assert_soft_invalid") + fromModule(v.module) + v.failure
is Script.Cmd.Assertion.Unlinkable ->
newMulti("assert_unlinkable") + fromModule(v.module) + v.failure
is Script.Cmd.Assertion.TrapModule ->
newMulti("assert_trap") + fromModule(v.module) + v.failure
}
fun fromCmd(v: Script.Cmd): SExpr.Multi = when(v) {

View File

@ -33,7 +33,7 @@ open class SExprToAst {
return when(exp.vals.first().symbolStr()) {
"assert_return" ->
Script.Cmd.Assertion.Return(toAction(mult),
exp.vals.drop(2).flatMap { toExprMaybe(it as SExpr.Multi, ExprContext(emptyMap())) })
exp.vals.drop(2).map { toExprMaybe(it as SExpr.Multi, ExprContext(emptyMap())) })
"assert_return_nan" ->
Script.Cmd.Assertion.ReturnNan(toAction(mult))
"assert_trap" ->
@ -525,6 +525,7 @@ open class SExprToAst {
fun toOpMaybe(exp: SExpr.Multi, offset: Int, ctx: ExprContext): Pair<Node.Instr, Int>? {
if (offset >= exp.vals.size) return null
val head = exp.vals[offset].symbol()!!
fun varIsStringRef() = exp.vals[offset + 1].symbolStr()?.firstOrNull() == '$'
fun oneVar() = toVar(exp.vals[offset + 1].symbol()!!, ctx.nameMap)
val op = InstrOp.strToOpMap[head.contents]
return when(op) {
@ -532,12 +533,14 @@ open class SExprToAst {
is InstrOp.ControlFlowOp.NoArg -> Pair(op.create, 1)
is InstrOp.ControlFlowOp.TypeArg -> return null // Type not handled here
is InstrOp.ControlFlowOp.DepthArg -> {
// Depth is special, because we actually subtract from our current depth
Pair(op.create(ctx.blockDepth - oneVar()), 2)
// Named depth is special, because we actually subtract from our current depth
if (varIsStringRef()) Pair(op.create(ctx.blockDepth - oneVar()), 2)
else Pair(op.create(oneVar()), 2)
}
is InstrOp.ControlFlowOp.TableArg -> {
require(!varIsStringRef()) { "String refs not supported in br_table yet" }
val vars = exp.vals.drop(offset + 1).takeUntilNullLazy { toVarMaybe(it, ctx.nameMap) }
Pair(op.create(vars.dropLast(1), vars.last()), offset + vars.size)
Pair(op.create(vars.dropLast(1), vars.last()), offset + 1 + vars.size)
}
is InstrOp.CallOp.IndexArg -> Pair(op.create(oneVar()), 2)
is InstrOp.CallOp.IndexReservedArg -> Pair(op.create(oneVar(), false), 2)
@ -562,8 +565,8 @@ open class SExprToAst {
Pair(op.create(instrAlign, instrOffset), count)
}
is InstrOp.MemOp.ReservedArg -> Pair(op.create(false), 1)
is InstrOp.ConstOp.IntArg -> Pair(op.create(exp.vals[offset + 1].symbol()!!.contents.toInt()), 2)
is InstrOp.ConstOp.LongArg -> Pair(op.create(exp.vals[offset + 1].symbol()!!.contents.toLong()), 2)
is InstrOp.ConstOp.IntArg -> Pair(op.create(exp.vals[offset + 1].symbol()!!.contents.toIntConst()), 2)
is InstrOp.ConstOp.LongArg -> Pair(op.create(exp.vals[offset + 1].symbol()!!.contents.toLongConst()), 2)
is InstrOp.ConstOp.FloatArg -> Pair(op.create(exp.vals[offset + 1].symbol()!!.contents.toFloat()), 2)
is InstrOp.ConstOp.DoubleArg -> Pair(op.create(exp.vals[offset + 1].symbol()!!.contents.toDouble()), 2)
is InstrOp.CompareOp.NoArg -> Pair(op.create, 1)
@ -662,6 +665,11 @@ open class SExprToAst {
}
}
private fun String.toIntConst() =
if (this.startsWith("0x")) this.substring(2).toInt(16) else this.toInt()
private fun String.toLongConst() =
if (this.startsWith("0x")) this.substring(2).toLong(16) else this.toLong()
private fun SExpr.requireSymbol(contents: String, quotedCheck: Boolean? = null) {
if (this is SExpr.Symbol && this.contents == contents &&
(quotedCheck == null || this.quoted == quotedCheck)) {

View File

@ -18,8 +18,8 @@ open class StrToSExpr {
ret += state.nextSExpr() ?: break
if (state.err != null) return ParseResult.Error(str.posFromOffset(state.offset), state.err!!)
}
val retVals = if (ret.size == 1 && ret[0] is SExpr.Multi) (ret[0] as SExpr.Multi).vals else ret
return ParseResult.Success(retVals, state.exprOffsetMap)
// val retVals = if (ret.size == 1 && ret[0] is SExpr.Multi) (ret[0] as SExpr.Multi).vals else ret
return ParseResult.Success(ret, state.exprOffsetMap)
}
private class ParseState(

View File

@ -17,7 +17,8 @@ data class ScriptContext(
val registrations: Map<String, Module> = emptyMap(),
val logger: Logger = Logger.Print(Logger.Level.OFF),
val adjustContext: (ClsContext) -> ClsContext = { it },
val classLoader: SimpleClassLoader = ScriptContext.SimpleClassLoader(ScriptContext::class.java.classLoader),
val classLoader: SimpleClassLoader =
ScriptContext.SimpleClassLoader(ScriptContext::class.java.classLoader, logger),
val exceptionTranslator: ExceptionTranslator = ExceptionTranslator
) {
fun withHarnessRegistered(out: PrintWriter = PrintWriter(System.out, true)) =
@ -37,11 +38,26 @@ data class ScriptContext(
fun doAssertion(cmd: Script.Cmd.Assertion) {
when (cmd) {
is Script.Cmd.Assertion.Return -> assertReturn(cmd)
is Script.Cmd.Assertion.Trap -> assertTrap(cmd)
is Script.Cmd.Assertion.Invalid -> assertInvalid(cmd)
else -> TODO()
}
}
fun assertReturn(ret: Script.Cmd.Assertion.Return) {
require(ret.exprs.size < 2)
val (retType, retVal) = doAction(ret.action)
when (retType) {
null -> if (ret.exprs.isNotEmpty()) throw AssertionError("Got empty return, expected not empty")
else -> {
if (ret.exprs.isEmpty()) throw AssertionError("Got return, expected empty")
val expectedVal = runExpr(ret.exprs.first(), retType)
if (retVal != expectedVal) throw AssertionError("Expected $expectedVal, got $retVal")
}
}
}
fun assertTrap(trap: Script.Cmd.Assertion.Trap) {
try { doAction(trap.action).also { throw AssertionError("Expected exception") } }
catch (e: Throwable) {
@ -52,6 +68,19 @@ data class ScriptContext(
}
}
fun assertInvalid(invalid: Script.Cmd.Assertion.Invalid) {
try {
val className = "invalid" + UUID.randomUUID().toString().replace("-", "")
compileModule(invalid.module, className, null)
throw AssertionError("Expected invalid module with error '${invalid.failure}', was valid")
} catch (e: Throwable) {
val innerEx = if (e is InvocationTargetException) e.targetException else e
exceptionTranslator.translateOrRethrow(innerEx).let {
if (it != invalid.failure) throw AssertionError("Expected invalid '${invalid.failure}' got '$it'")
}
}
}
fun doAction(cmd: Script.Cmd.Action) = when (cmd) {
is Script.Cmd.Action.Invoke -> doInvoke(cmd)
is Script.Cmd.Action.Get -> doGet(cmd)
@ -99,8 +128,8 @@ data class ScriptContext(
instructions = insns
))
)
val name = "expr" + UUID.randomUUID().toString().replace("-", "")
val compiled = compileModule(mod, name, null)
val className = "expr" + UUID.randomUUID().toString().replace("-", "")
val compiled = compileModule(mod, className, null)
return MethodHandles.lookup().bind(compiled.instance, "expr",
MethodType.methodType(retType?.jclass ?: Void.TYPE))
}
@ -163,10 +192,13 @@ data class ScriptContext(
}
}
open class SimpleClassLoader(parent: ClassLoader) : ClassLoader(parent) {
fun fromBuiltContext(ctx: ClsContext) = ctx.cls.withComputedFramesAndMaxs().let { bytes ->
ctx.debug { "ASM Class:\n" + bytes.asClassNode().toAsmString() }
defineClass("${ctx.packageName}.${ctx.className}", bytes, 0, bytes.size)
open class SimpleClassLoader(parent: ClassLoader, logger: Logger) : ClassLoader(parent), Logger by logger {
fun fromBuiltContext(ctx: ClsContext): Class<*> {
ctx.trace { "Computing frames for ASM class:\n" + ctx.cls.toAsmString() }
return ctx.cls.withComputedFramesAndMaxs().let { bytes ->
ctx.debug { "ASM class:\n" + bytes.asClassNode().toAsmString() }
defineClass("${ctx.packageName}.${ctx.className}", bytes, 0, bytes.size)
}
}
}
}

View File

@ -1,5 +1,6 @@
package asmble
import asmble.ast.SExpr
import asmble.io.AstToSExpr
import asmble.io.SExprToStr
import asmble.run.jvm.ScriptContext
@ -12,7 +13,7 @@ import java.io.StringWriter
import kotlin.test.assertEquals
@RunWith(Parameterized::class)
class CoreTest(val unit: CoreTestUnit) : Logger by Logger.Print(Logger.Level.INFO) {
class CoreTest(val unit: CoreTestUnit) : Logger by Logger.Print(Logger.Level.TRACE) {
@Test
fun testName() {
@ -32,7 +33,7 @@ class CoreTest(val unit: CoreTestUnit) : Logger by Logger.Print(Logger.Level.INF
// This will fail assertions as necessary
unit.script.commands.fold(scriptContext, ScriptContext::runCommand)
assertEquals(unit.expectedOutput, out.toString())
unit.expectedOutput?.let { assertEquals(it, out.toString()) }
}
companion object {

View File

@ -39,7 +39,7 @@ class CoreTestUnit(val name: String, val wast: String, val expectedOutput: Strin
val fs = if (uri.scheme == "jar") FileSystems.newFileSystem(uri, emptyMap<String, Any>()) else null
fs.use { fs ->
val path = fs?.getPath(basePath) ?: Paths.get(uri)
return Files.walk(path, 1).filter { it.toString().endsWith("address.wast") }.map {
return Files.walk(path, 1).filter { it.toString().endsWith("block.wast") }.map {
val name = it.fileName.toString().substringBeforeLast(".wast")
CoreTestUnit(
name = name,