asmble/src/main/kotlin/asmble/compile/jvm/InsnReworker.kt

package asmble.compile.jvm

import asmble.ast.Node
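
/**
 * Pre-pass over a function's WASM instructions before JVM bytecode is emitted:
 * it prepends eager zero-initializers for locals the JVM verifier could
 * otherwise see read before being written, and records where synthetic stack
 * values (e.g. "this" or the memory reference) must be injected beneath
 * operands that are already on the stack.
 */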
open class InsnReworker {

    fun rework(ctx: ClsContext, func: Node.Func): List<Insn> {
        return injectNeededStackVars(ctx, func.instructions).let { insns ->
            addEagerLocalInitializers(ctx, func, insns)
        }
    }

    fun addEagerLocalInitializers(ctx: ClsContext, func: Node.Func, insns: List<Insn>): List<Insn> {
        if (func.locals.isEmpty()) return insns
        // The JVM requires you to set a local before you access it; WASM requires
        // that all locals are implicitly zero. After some thought, we use an easy
        // algorithm here: for any get_local, there must be a set/tee_local in a
        // preceding insn before a branch of any form (br, br_if, or br_table).
        // If there isn't, an eager set_local is added at the beginning to init
        // the local to 0.
        //
        // This should prevent any false positives (i.e. a get_local the verifier
        // could see executed before any set/tee_local) while reducing false
        // negatives (i.e. an unnecessary eager init for a get_local that is in
        // fact always preceded by a set/tee). Sure, there are more accurate ways,
        // such as injecting sets exactly where needed, turning the first non-set
        // get into a tee, or counting specific block depths, but this keeps it
        // simple for now.
        //
        // Note, while walking backwards up the insns to find a set/tee, we skip
        // entire block/loop/if+else regions that are closed with "end".
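        //
        // Worked example (hypothetical body for a function whose only non-param
        // local is an i32):
        //   block  i32.const 1  set_local 0  end  get_local 0
        // Walking backwards from the get_local: "end" bumps blockInitsToSkip to 1,
        // so the set_local inside the block is ignored, "block" drops the count
        // back to 0, and nothing earlier is found, so an eager
        // "i32.const 0 / set_local 0" pair is prepended anyway (conservative, as
        // described above).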
        var neededEagerLocalIndices = emptySet<Int>()
        fun addEagerSetIfNeeded(getInsnIndex: Int, localIndex: Int) {
            // Within the param range? Nothing needed
            if (localIndex < func.type.params.size) return
            // Already marked for eager init? Nothing needed
            if (neededEagerLocalIndices.contains(localIndex)) return
            var blockInitsToSkip = 0
            // Get the first set/tee or branching insn (or nothing of course)
            val insn = insns.take(getInsnIndex).asReversed().find { insn ->
                insn is Insn.Node && when (insn.insn) {
                    // End means we need to skip to the next block start
                    is Node.Instr.End -> {
                        blockInitsToSkip++
                        false
                    }
                    // Else with no inits to skip means we are in the else branch
                    // and should skip back to the matching if (i.e. ignore
                    // everything between the if and the else)
                    is Node.Instr.Else -> {
                        if (blockInitsToSkip == 0) blockInitsToSkip++
                        false
                    }
                    // Block init, decrement the skip count
                    is Node.Instr.Block, is Node.Instr.Loop, is Node.Instr.If -> {
                        if (blockInitsToSkip > 0) blockInitsToSkip--
                        false
                    }
                    // Branch means we found it if we're not skipping
                    is Node.Instr.Br, is Node.Instr.BrIf, is Node.Instr.BrTable ->
                        blockInitsToSkip == 0
                    // Set/Tee means we found it if the index is right
                    // and we're not skipping
                    is Node.Instr.SetLocal, is Node.Instr.TeeLocal ->
                        blockInitsToSkip == 0 && (insn.insn as Node.Instr.Args.Index).index == localIndex
                    // Anything else doesn't matter
                    else -> false
                }
            }
            // If the found insn is not a set or tee, we have to eager init
            val needsEagerInit = insn == null ||
                (insn is Insn.Node && insn.insn !is Node.Instr.SetLocal && insn.insn !is Node.Instr.TeeLocal)
            if (needsEagerInit) neededEagerLocalIndices += localIndex
        }
        insns.forEachIndexed { index, insn ->
            if (insn is Insn.Node && insn.insn is Node.Instr.GetLocal) addEagerSetIfNeeded(index, insn.insn.index)
        }
        // Now, in local order, prepend needed local inits
        return neededEagerLocalIndices.sorted().flatMap {
            val const: Node.Instr = when (func.localByIndex(it)) {
                is Node.Type.Value.I32 -> Node.Instr.I32Const(0)
                is Node.Type.Value.I64 -> Node.Instr.I64Const(0)
                is Node.Type.Value.F32 -> Node.Instr.F32Const(0f)
                is Node.Type.Value.F64 -> Node.Instr.F64Const(0.0)
            }
            listOf(Insn.Node(const), Insn.Node(Node.Instr.SetLocal(it)))
        } + insns
    }

    fun injectNeededStackVars(ctx: ClsContext, insns: List<Node.Instr>): List<Insn> {
        // How we do this:
        // We run over each insn and keep a running list of stack manips. If there
        // is an insn that needs something placed further back in the stack, we
        // calc where it needs to be added and keep a running list of insn
        // inserts. Then at the end we settle up.
        //
        // Note, we don't do any injections for things like "this" if
        // they aren't needed up the stack (e.g. a simple getfield can
        // just aload 0 itself).

        // Each pair is first the amount of stack that is changed (0 is ignored,
        // push is positive, pull is negative), then the index of the insn that
        // caused it. As a special case, call_indirect's stack effect depends on
        // the type at its index, but it is still known statically (see
        // insnStackDiff below).
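        //
        // For example (hypothetical sequence): after seeing
        //   i32.const 1, i32.const 2, i32.add
        // the list is [(+1, 0), (+1, 1), (-1, 2)].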
        var stackManips = emptyList<Pair<Int, Int>>()
        // Keyed by the index to inject at. With how the algorithm works, we
        // guarantee the values will be in the right order if there are
        // multiple for the same index
        var insnsToInject = emptyMap<Int, List<Insn>>()
        fun injectBeforeLastStackCount(insn: Insn, count: Int) {
            ctx.trace { "Injecting $insn back $count stack values" }
            fun inject(index: Int) {
                insnsToInject += index to (insnsToInject[index]?.let { listOf(insn) + it } ?: listOf(insn))
            }
            if (count == 0) return inject(stackManips.size)
            var countSoFar = 0
            var foundUnconditionalJump = false
            for ((amountChanged, insnIndex) in stackManips.asReversed()) {
                countSoFar += amountChanged
                if (!foundUnconditionalJump) foundUnconditionalJump = insns[insnIndex].let { insn ->
                    insn is Node.Instr.Br || insn is Node.Instr.BrTable ||
                        insn is Node.Instr.Unreachable || insn is Node.Instr.Return
                }
                if (countSoFar == count) return inject(insnIndex)
            }
            // Only consider it a failure if we didn't hit any unconditional jumps
            if (!foundUnconditionalJump) throw CompileErr.StackInjectionMismatch(count, insn)
        }
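        // Worked example (hypothetical func $f taking two i32 params): for
        //   i32.const 1, i32.const 2, call $f
        // the Call case below asks for "this" to be placed back 2 stack values.
        // Walking stackManips in reverse: +1 (the second const) gives countSoFar
        // = 1, then +1 (the first const) gives countSoFar = 2 == count, so
        // ThisNeededOnStack is injected at index 0, i.e. beneath both arguments.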
        // Go over each insn, determining where to inject
        insns.forEachIndexed { index, insn ->
            // Handle special injection cases
            when (insn) {
                // Calls require "this" or a fn ref before the params
                is Node.Instr.Call -> {
                    val inject =
                        if (insn.index < ctx.importFuncs.size) Insn.ImportFuncRefNeededOnStack(insn.index)
                        else Insn.ThisNeededOnStack
                    injectBeforeLastStackCount(inject, ctx.funcTypeAtIndex(insn.index).params.size)
                }
                // Indirect calls require "this" before the index
                is Node.Instr.CallIndirect ->
                    injectBeforeLastStackCount(Insn.ThisNeededOnStack, 1)
                // Global set requires "this" before the single param
                is Node.Instr.SetGlobal -> {
                    val inject =
                        if (insn.index < ctx.importGlobals.size) Insn.ImportGlobalSetRefNeededOnStack(insn.index)
                        else Insn.ThisNeededOnStack
                    injectBeforeLastStackCount(inject, 1)
                }
                // Loads require "mem" before the single param
                is Node.Instr.I32Load, is Node.Instr.I64Load, is Node.Instr.F32Load, is Node.Instr.F64Load,
                is Node.Instr.I32Load8S, is Node.Instr.I32Load8U, is Node.Instr.I32Load16U, is Node.Instr.I32Load16S,
                is Node.Instr.I64Load8S, is Node.Instr.I64Load8U, is Node.Instr.I64Load16U, is Node.Instr.I64Load16S,
                is Node.Instr.I64Load32S, is Node.Instr.I64Load32U ->
                    injectBeforeLastStackCount(Insn.MemNeededOnStack, 1)
                // Stores require "mem" before the two params (address and value)
                is Node.Instr.I32Store, is Node.Instr.I64Store, is Node.Instr.F32Store, is Node.Instr.F64Store,
                is Node.Instr.I32Store8, is Node.Instr.I32Store16, is Node.Instr.I64Store8, is Node.Instr.I64Store16,
                is Node.Instr.I64Store32 ->
                    injectBeforeLastStackCount(Insn.MemNeededOnStack, 2)
                // Grow memory requires "mem" before the single param
                is Node.Instr.GrowMemory ->
                    injectBeforeLastStackCount(Insn.MemNeededOnStack, 1)
                else -> { }
            }
            // Add the current diff
            ctx.trace { "Stack diff is ${insnStackDiff(ctx, insn)} for $insn" }
            stackManips += insnStackDiff(ctx, insn) to index
        }
        // Build the resulting list
        return insns.foldIndexed(emptyList<Insn>()) { index, ret, insn ->
            val injections = insnsToInject[index] ?: emptyList()
            ret + injections + Insn.Node(insn)
        }
    }
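
    // Net stack effect of a single insn, expressed in the POP_PARAM / PUSH_RESULT /
    // NOP units from the companion object; e.g. i32.add is
    // POP_PARAM + POP_PARAM + PUSH_RESULT = -1.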
    fun insnStackDiff(ctx: ClsContext, insn: Node.Instr) = when (insn) {
        is Node.Instr.Unreachable, is Node.Instr.Nop, is Node.Instr.Block,
        is Node.Instr.Loop, is Node.Instr.If, is Node.Instr.Else,
        is Node.Instr.End, is Node.Instr.Br, is Node.Instr.BrIf,
        is Node.Instr.Return -> NOP
        is Node.Instr.BrTable -> POP_PARAM
        is Node.Instr.Call -> ctx.funcTypeAtIndex(insn.index).let {
            // All calls pop params and any return is a push
            (POP_PARAM * it.params.size) + (if (it.ret == null) NOP else PUSH_RESULT)
        }
        is Node.Instr.CallIndirect -> ctx.typeAtIndex(insn.index).let {
            // We add one for the table index
            POP_PARAM + (POP_PARAM * it.params.size) + (if (it.ret == null) NOP else PUSH_RESULT)
        }
        is Node.Instr.Drop -> POP_PARAM
        is Node.Instr.Select -> (POP_PARAM * 3) + PUSH_RESULT
        is Node.Instr.GetLocal -> PUSH_RESULT
        is Node.Instr.SetLocal -> POP_PARAM
        is Node.Instr.TeeLocal -> POP_PARAM + PUSH_RESULT
        is Node.Instr.GetGlobal -> PUSH_RESULT
        is Node.Instr.SetGlobal -> POP_PARAM
        is Node.Instr.I32Load, is Node.Instr.I64Load, is Node.Instr.F32Load, is Node.Instr.F64Load,
        is Node.Instr.I32Load8S, is Node.Instr.I32Load8U, is Node.Instr.I32Load16U, is Node.Instr.I32Load16S,
        is Node.Instr.I64Load8S, is Node.Instr.I64Load8U, is Node.Instr.I64Load16U, is Node.Instr.I64Load16S,
        is Node.Instr.I64Load32S, is Node.Instr.I64Load32U -> POP_PARAM + PUSH_RESULT
        is Node.Instr.I32Store, is Node.Instr.I64Store, is Node.Instr.F32Store, is Node.Instr.F64Store,
        is Node.Instr.I32Store8, is Node.Instr.I32Store16, is Node.Instr.I64Store8, is Node.Instr.I64Store16,
        is Node.Instr.I64Store32 -> POP_PARAM
        is Node.Instr.CurrentMemory -> PUSH_RESULT
        is Node.Instr.GrowMemory -> POP_PARAM + PUSH_RESULT
        is Node.Instr.I32Const, is Node.Instr.I64Const,
        is Node.Instr.F32Const, is Node.Instr.F64Const -> PUSH_RESULT
        is Node.Instr.I32Add, is Node.Instr.I32Sub, is Node.Instr.I32Mul, is Node.Instr.I32DivS,
        is Node.Instr.I32DivU, is Node.Instr.I32RemS, is Node.Instr.I32RemU, is Node.Instr.I32And,
        is Node.Instr.I32Or, is Node.Instr.I32Xor, is Node.Instr.I32Shl, is Node.Instr.I32ShrS,
        is Node.Instr.I32ShrU, is Node.Instr.I32Rotl, is Node.Instr.I32Rotr, is Node.Instr.I32Eq,
        is Node.Instr.I32Ne, is Node.Instr.I32LtS, is Node.Instr.I32LeS, is Node.Instr.I32LtU,
        is Node.Instr.I32LeU, is Node.Instr.I32GtS, is Node.Instr.I32GeS, is Node.Instr.I32GtU,
        is Node.Instr.I32GeU -> POP_PARAM + POP_PARAM + PUSH_RESULT
        is Node.Instr.I32Clz, is Node.Instr.I32Ctz, is Node.Instr.I32Popcnt,
        is Node.Instr.I32Eqz -> POP_PARAM + PUSH_RESULT
        is Node.Instr.I64Add, is Node.Instr.I64Sub, is Node.Instr.I64Mul, is Node.Instr.I64DivS,
        is Node.Instr.I64DivU, is Node.Instr.I64RemS, is Node.Instr.I64RemU, is Node.Instr.I64And,
        is Node.Instr.I64Or, is Node.Instr.I64Xor, is Node.Instr.I64Shl, is Node.Instr.I64ShrS,
        is Node.Instr.I64ShrU, is Node.Instr.I64Rotl, is Node.Instr.I64Rotr, is Node.Instr.I64Eq,
        is Node.Instr.I64Ne, is Node.Instr.I64LtS, is Node.Instr.I64LeS, is Node.Instr.I64LtU,
        is Node.Instr.I64LeU, is Node.Instr.I64GtS, is Node.Instr.I64GeS, is Node.Instr.I64GtU,
        is Node.Instr.I64GeU -> POP_PARAM + POP_PARAM + PUSH_RESULT
        is Node.Instr.I64Clz, is Node.Instr.I64Ctz, is Node.Instr.I64Popcnt,
        is Node.Instr.I64Eqz -> POP_PARAM + PUSH_RESULT
        // Note: f32.sqrt/f64.sqrt are unary, so they are grouped with the
        // single-operand ops rather than the two-operand ops
        is Node.Instr.F32Add, is Node.Instr.F32Sub, is Node.Instr.F32Mul, is Node.Instr.F32Div,
        is Node.Instr.F32Eq, is Node.Instr.F32Ne, is Node.Instr.F32Lt, is Node.Instr.F32Le,
        is Node.Instr.F32Gt, is Node.Instr.F32Ge, is Node.Instr.F32Min,
        is Node.Instr.F32Max, is Node.Instr.F32CopySign -> POP_PARAM + POP_PARAM + PUSH_RESULT
        is Node.Instr.F32Abs, is Node.Instr.F32Neg, is Node.Instr.F32Ceil, is Node.Instr.F32Sqrt,
        is Node.Instr.F32Floor, is Node.Instr.F32Trunc, is Node.Instr.F32Nearest -> POP_PARAM + PUSH_RESULT
        is Node.Instr.F64Add, is Node.Instr.F64Sub, is Node.Instr.F64Mul, is Node.Instr.F64Div,
        is Node.Instr.F64Eq, is Node.Instr.F64Ne, is Node.Instr.F64Lt, is Node.Instr.F64Le,
        is Node.Instr.F64Gt, is Node.Instr.F64Ge, is Node.Instr.F64Min,
        is Node.Instr.F64Max, is Node.Instr.F64CopySign -> POP_PARAM + POP_PARAM + PUSH_RESULT
        is Node.Instr.F64Abs, is Node.Instr.F64Neg, is Node.Instr.F64Ceil, is Node.Instr.F64Sqrt,
        is Node.Instr.F64Floor, is Node.Instr.F64Trunc, is Node.Instr.F64Nearest -> POP_PARAM + PUSH_RESULT
        is Node.Instr.I32WrapI64, is Node.Instr.I32TruncSF32, is Node.Instr.I32TruncUF32,
        is Node.Instr.I32TruncSF64, is Node.Instr.I32TruncUF64, is Node.Instr.I64ExtendSI32,
        is Node.Instr.I64ExtendUI32, is Node.Instr.I64TruncSF32, is Node.Instr.I64TruncUF32,
        is Node.Instr.I64TruncSF64, is Node.Instr.I64TruncUF64, is Node.Instr.F32ConvertSI32,
        is Node.Instr.F32ConvertUI32, is Node.Instr.F32ConvertSI64, is Node.Instr.F32ConvertUI64,
        is Node.Instr.F32DemoteF64, is Node.Instr.F64ConvertSI32, is Node.Instr.F64ConvertUI32,
        is Node.Instr.F64ConvertSI64, is Node.Instr.F64ConvertUI64, is Node.Instr.F64PromoteF32,
        is Node.Instr.I32ReinterpretF32, is Node.Instr.I64ReinterpretF64, is Node.Instr.F32ReinterpretI32,
        is Node.Instr.F64ReinterpretI64 -> POP_PARAM + PUSH_RESULT
    }
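
    // Counts the spots that need "mem" on the stack, skipping ones that
    // immediately follow an insn that could have left "mem" there (a store or
    // grow_memory). Callers can presumably use this to decide whether caching
    // "mem" in a local variable is worthwhile (that use is not visible in this
    // file).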
    fun nonAdjacentMemAccesses(insns: List<Insn>) = insns.fold(0 to false) { (count, lastCouldHaveMem), insn ->
        val inc =
            if (lastCouldHaveMem) 0
            else if (insn == Insn.MemNeededOnStack) 1
            else if (insn is Insn.Node && insn.insn is Node.Instr.CurrentMemory) 1
            else 0
        val couldSetMemNext = if (insn !is Insn.Node) false else when (insn.insn) {
            is Node.Instr.I32Store, is Node.Instr.I64Store, is Node.Instr.F32Store, is Node.Instr.F64Store,
            is Node.Instr.I32Store8, is Node.Instr.I32Store16, is Node.Instr.I64Store8, is Node.Instr.I64Store16,
            is Node.Instr.I64Store32, is Node.Instr.GrowMemory -> true
            else -> false
        }
        (count + inc) to couldSetMemNext
    }.let { (count, _) -> count }

    companion object : InsnReworker() {
        const val POP_PARAM = -1
        const val PUSH_RESULT = 1
        const val NOP = 0
    }
}