merge original-master to master

This commit is contained in:
Constantine Solovev
2018-08-10 13:10:28 +04:00
12 changed files with 803 additions and 46 deletions

2
.gitignore vendored
View File

@ -16,6 +16,8 @@
/annotations/out
/examples/c-simple/bin
/examples/c-simple/build
/examples/go-simple/bin
/examples/go-simple/build
/examples/rust-simple/Cargo.lock
/examples/rust-simple/bin
/examples/rust-simple/build

View File

@ -67,6 +67,31 @@ project(':examples') {
compileOnly project(':compiler')
}
// Go example helpers
task goToWasm {
doFirst {
mkdir 'build'
exec {
def goFileName = fileTree(dir: '.', includes: ['*.go']).files.iterator().next()
environment 'GOOS': 'js', 'GOARCH': 'wasm'
commandLine 'go', 'build', '-o', 'build/lib.wasm', goFileName
}
}
}
task compileGoWasm(type: JavaExec) {
dependsOn goToWasm
classpath configurations.compileClasspath
main = 'asmble.cli.MainKt'
doFirst {
// args 'help', 'compile'
def outFile = 'build/wasm-classes/' + wasmCompiledClassName.replace('.', '/') + '.class'
file(outFile).parentFile.mkdirs()
args 'compile', 'build/lib.wasm', wasmCompiledClassName, '-out', outFile, '-log', 'debug'
}
}
// Rust example helpers
ext.rustBuildRelease = true
@ -116,30 +141,43 @@ project(':examples') {
}
}
project(':examples:go-simple') {
apply plugin: 'application'
ext.wasmCompiledClassName = 'asmble.generated.GoSimple'
dependencies {
compile files('build/wasm-classes')
}
compileJava {
dependsOn compileGoWasm
}
mainClassName = 'asmble.examples.gosimple.Main'
}
// todo temporary disable Rust regex, because some strings in wasm code exceed the size in 65353 bytes.
//project(':examples:rust-regex') {
// apply plugin: 'application'
// apply plugin: 'me.champeau.gradle.jmh'
// ext.wasmCompiledClassName = 'asmble.generated.RustRegex'
// dependencies {
// compile files('build/wasm-classes')
// testCompile 'junit:junit:4.12'
// }
// compileJava {
// dependsOn compileRustWasm
// }
// mainClassName = 'asmble.examples.rustregex.Main'
// test {
// testLogging.showStandardStreams = true
// testLogging.events 'PASSED', 'SKIPPED'
// }
// jmh {
// iterations = 5
// warmupIterations = 5
// fork = 3
// }
//}
// project(':examples:rust-regex') {
// apply plugin: 'application'
// apply plugin: 'me.champeau.gradle.jmh'
// ext.wasmCompiledClassName = 'asmble.generated.RustRegex'
// dependencies {
// compile files('build/wasm-classes')
// testCompile 'junit:junit:4.12'
// }
// compileJava {
// dependsOn compileRustWasm
// }
// mainClassName = 'asmble.examples.rustregex.Main'
// test {
// testLogging.showStandardStreams = true
// testLogging.events 'PASSED', 'SKIPPED'
// }
// jmh {
// iterations = 5
// warmupIterations = 5
// fork = 3
// }
// }
project(':examples:rust-simple') {
apply plugin: 'application'

View File

@ -0,0 +1,250 @@
package asmble.ast
// This is a utility for walking the stack. It can do validation or just walk naively.
data class Stack(
// If some of these values below are null, the pops/pushes may appear "unknown"
val mod: CachedModule? = null,
val func: Node.Func? = null,
// Null if not tracking the current stack and all pops succeed
val current: List<Node.Type.Value>? = null,
val insnApplies: List<InsnApply> = emptyList(),
val strict: Boolean = false,
val unreachableUntilNextEndCount: Int = 0
) {
fun next(v: Node.Instr, callFuncTypeOverride: Node.Type.Func? = null) = insnApply(v) {
// If we're unreachable, and not an end, we skip and move on
if (unreachableUntilNextEndCount > 0 && v !is Node.Instr.End) {
// If it's a block, we increase it because we'll see another end
return@insnApply if (v is Node.Instr.Args.Type) unreachable(unreachableUntilNextEndCount + 1) else nop()
}
when (v) {
is Node.Instr.Nop, is Node.Instr.Block, is Node.Instr.Loop -> nop()
is Node.Instr.If, is Node.Instr.BrIf -> popI32()
is Node.Instr.Return -> (func?.type?.ret?.let { pop(it) } ?: nop()) + unreachable(1)
is Node.Instr.Unreachable -> unreachable(1)
is Node.Instr.End, is Node.Instr.Else -> {
// Put back what was before the last block and add the block's type
// Go backwards to find the starting block
var currDepth = 0
val found = insnApplies.findLast {
when (it.insn) {
is Node.Instr.End -> { currDepth++; false }
is Node.Instr.Args.Type -> if (currDepth > 0) { currDepth--; false } else true
else -> false
}
}?.takeIf {
// When it's else, needs to be if
v !is Node.Instr.Else || it.insn is Node.Instr.If
}
val changes = when {
found != null && found.insn is Node.Instr.Args.Type &&
found.stackAtBeginning != null && this != null -> {
// Pop everything from before the block's start, then push if necessary...
// The If block includes an int at the beginning we must not include when subtracting
var preBlockStackSize = found.stackAtBeginning.size
if (found.insn is Node.Instr.If) preBlockStackSize--
val popped =
if (unreachableUntilNextEndCount > 1) nop()
else (0 until (size - preBlockStackSize)).flatMap { pop() }
// Only push if this is not an else
val pushed =
if (unreachableUntilNextEndCount > 1 || v is Node.Instr.Else) nop()
else (found.insn.type?.let { push(it) } ?: nop())
popped + pushed
}
strict -> error("Unable to find starting block for end")
else -> nop()
}
if (unreachableUntilNextEndCount > 0) changes + unreachable(unreachableUntilNextEndCount - 1)
else changes
}
is Node.Instr.Br -> unreachable(v.relativeDepth + 1)
is Node.Instr.BrTable -> popI32() + unreachable(1)
is Node.Instr.Call -> (callFuncTypeOverride ?: func(v.index)).let {
if (it == null) error("Call func type missing")
it.params.reversed().flatMap { pop(it) } + (it.ret?.let { push(it) } ?: nop())
}
is Node.Instr.CallIndirect -> (callFuncTypeOverride ?: mod?.mod?.types?.getOrNull(v.index)).let {
if (it == null) error("Call func type missing")
// We add one for the table index
popI32() + it.params.reversed().flatMap { pop(it) } + (it.ret?.let { push(it) } ?: nop())
}
is Node.Instr.Drop -> pop()
is Node.Instr.Select -> popI32() + pop().let { it + pop(it.first().type) + push(it.first().type) }
is Node.Instr.GetLocal -> push(local(v.index))
is Node.Instr.SetLocal -> pop(local(v.index))
is Node.Instr.TeeLocal -> local(v.index).let { pop(it) + push(it) }
is Node.Instr.GetGlobal -> push(global(v.index))
is Node.Instr.SetGlobal -> pop(global(v.index))
is Node.Instr.I32Load, is Node.Instr.I32Load8S, is Node.Instr.I32Load8U,
is Node.Instr.I32Load16U, is Node.Instr.I32Load16S -> popI32() + pushI32()
is Node.Instr.I64Load, is Node.Instr.I64Load8S, is Node.Instr.I64Load8U, is Node.Instr.I64Load16U,
is Node.Instr.I64Load16S, is Node.Instr.I64Load32S, is Node.Instr.I64Load32U -> popI32() + pushI64()
is Node.Instr.F32Load -> popI32() + pushF32()
is Node.Instr.F64Load -> popI32() + pushF64()
is Node.Instr.I32Store, is Node.Instr.I32Store8, is Node.Instr.I32Store16 -> popI32() + popI32()
is Node.Instr.I64Store, is Node.Instr.I64Store8,
is Node.Instr.I64Store16, is Node.Instr.I64Store32 -> popI64() + popI32()
is Node.Instr.F32Store -> popF32() + popI32()
is Node.Instr.F64Store -> popF64() + popI32()
is Node.Instr.MemorySize -> pushI32()
is Node.Instr.MemoryGrow -> popI32() + pushI32()
is Node.Instr.I32Const -> pushI32()
is Node.Instr.I64Const -> pushI64()
is Node.Instr.F32Const -> pushF32()
is Node.Instr.F64Const -> pushF64()
is Node.Instr.I32Add, is Node.Instr.I32Sub, is Node.Instr.I32Mul, is Node.Instr.I32DivS,
is Node.Instr.I32DivU, is Node.Instr.I32RemS, is Node.Instr.I32RemU, is Node.Instr.I32And,
is Node.Instr.I32Or, is Node.Instr.I32Xor, is Node.Instr.I32Shl, is Node.Instr.I32ShrS,
is Node.Instr.I32ShrU, is Node.Instr.I32Rotl, is Node.Instr.I32Rotr, is Node.Instr.I32Eq,
is Node.Instr.I32Ne, is Node.Instr.I32LtS, is Node.Instr.I32LeS, is Node.Instr.I32LtU,
is Node.Instr.I32LeU, is Node.Instr.I32GtS, is Node.Instr.I32GeS, is Node.Instr.I32GtU,
is Node.Instr.I32GeU -> popI32() + popI32() + pushI32()
is Node.Instr.I32Clz, is Node.Instr.I32Ctz, is Node.Instr.I32Popcnt,
is Node.Instr.I32Eqz -> popI32() + pushI32()
is Node.Instr.I64Add, is Node.Instr.I64Sub, is Node.Instr.I64Mul, is Node.Instr.I64DivS,
is Node.Instr.I64DivU, is Node.Instr.I64RemS, is Node.Instr.I64RemU, is Node.Instr.I64And,
is Node.Instr.I64Or, is Node.Instr.I64Xor, is Node.Instr.I64Shl, is Node.Instr.I64ShrS,
is Node.Instr.I64ShrU, is Node.Instr.I64Rotl, is Node.Instr.I64Rotr -> popI64() + popI64() + pushI64()
is Node.Instr.I64Eq, is Node.Instr.I64Ne, is Node.Instr.I64LtS, is Node.Instr.I64LeS,
is Node.Instr.I64LtU, is Node.Instr.I64LeU, is Node.Instr.I64GtS,
is Node.Instr.I64GeS, is Node.Instr.I64GtU, is Node.Instr.I64GeU -> popI64() + popI64() + pushI32()
is Node.Instr.I64Clz, is Node.Instr.I64Ctz, is Node.Instr.I64Popcnt -> popI64() + pushI64()
is Node.Instr.I64Eqz -> popI64() + pushI32()
is Node.Instr.F32Add, is Node.Instr.F32Sub, is Node.Instr.F32Mul, is Node.Instr.F32Div,
is Node.Instr.F32Min, is Node.Instr.F32Max, is Node.Instr.F32CopySign -> popF32() + popF32() + pushF32()
is Node.Instr.F32Eq, is Node.Instr.F32Ne, is Node.Instr.F32Lt, is Node.Instr.F32Le,
is Node.Instr.F32Gt, is Node.Instr.F32Ge -> popF32() + popF32() + pushI32()
is Node.Instr.F32Abs, is Node.Instr.F32Neg, is Node.Instr.F32Ceil, is Node.Instr.F32Floor,
is Node.Instr.F32Trunc, is Node.Instr.F32Nearest, is Node.Instr.F32Sqrt -> popF32() + pushF32()
is Node.Instr.F64Add, is Node.Instr.F64Sub, is Node.Instr.F64Mul, is Node.Instr.F64Div,
is Node.Instr.F64Min, is Node.Instr.F64Max, is Node.Instr.F64CopySign -> popF64() + popF64() + pushF64()
is Node.Instr.F64Eq, is Node.Instr.F64Ne, is Node.Instr.F64Lt, is Node.Instr.F64Le,
is Node.Instr.F64Gt, is Node.Instr.F64Ge -> popF64() + popF64() + pushI32()
is Node.Instr.F64Abs, is Node.Instr.F64Neg, is Node.Instr.F64Ceil, is Node.Instr.F64Floor,
is Node.Instr.F64Trunc, is Node.Instr.F64Nearest, is Node.Instr.F64Sqrt -> popF64() + pushF64()
is Node.Instr.I32WrapI64 -> popI64() + pushI32()
is Node.Instr.I32TruncSF32, is Node.Instr.I32TruncUF32,
is Node.Instr.I32ReinterpretF32 -> popF32() + pushI32()
is Node.Instr.I32TruncSF64, is Node.Instr.I32TruncUF64 -> popF64() + pushI32()
is Node.Instr.I64ExtendSI32, is Node.Instr.I64ExtendUI32 -> popI32() + pushI64()
is Node.Instr.I64TruncSF32, is Node.Instr.I64TruncUF32 -> popF32() + pushI64()
is Node.Instr.I64TruncSF64, is Node.Instr.I64TruncUF64,
is Node.Instr.I64ReinterpretF64 -> popF64() + pushI64()
is Node.Instr.F32ConvertSI32, is Node.Instr.F32ConvertUI32,
is Node.Instr.F32ReinterpretI32 -> popI32() + pushF32()
is Node.Instr.F32ConvertSI64, is Node.Instr.F32ConvertUI64 -> popI64() + pushF32()
is Node.Instr.F32DemoteF64 -> popF64() + pushF32()
is Node.Instr.F64ConvertSI32, is Node.Instr.F64ConvertUI32 -> popI32() + pushF64()
is Node.Instr.F64ConvertSI64, is Node.Instr.F64ConvertUI64,
is Node.Instr.F64ReinterpretI64 -> popI64() + pushF64()
is Node.Instr.F64PromoteF32 -> popF32() + pushF64()
}
}
protected fun insnApply(v: Node.Instr, fn: MutableList<Node.Type.Value>?.() -> List<InsnApplyResponse>): Stack {
val mutStack = current?.toMutableList()
val applyResp = mutStack.fn()
val newUnreachable = (applyResp.find { it is Unreachable } as? Unreachable)?.untilEndCount
return copy(
current = mutStack,
insnApplies = insnApplies + InsnApply(
insn = v,
stackAtBeginning = current,
stackChanges = applyResp.mapNotNull { it as? StackChange },
unreachableUntilEndCount = newUnreachable ?: unreachableUntilNextEndCount
),
unreachableUntilNextEndCount = newUnreachable ?: unreachableUntilNextEndCount
)
}
protected fun unreachable(untilEndCount: Int) = listOf(Unreachable(untilEndCount))
protected fun local(index: Int) = func?.let {
it.type.params.getOrNull(index) ?: it.locals.getOrNull(index - it.type.params.size)
}
protected fun global(index: Int) = mod?.let {
it.importGlobals.getOrNull(index)?.type?.contentType ?:
it.mod.globals.getOrNull(index - it.importGlobals.size)?.type?.contentType
}
protected fun func(index: Int) = mod?.let {
it.importFuncs.getOrNull(index)?.typeIndex?.let { i -> it.mod.types.getOrNull(i) } ?:
it.mod.funcs.getOrNull(index - it.importFuncs.size)?.type
}
protected fun nop() = emptyList<StackChange>()
protected fun MutableList<Node.Type.Value>?.popType(expecting: Node.Type.Value? = null) =
this?.takeIf {
it.isNotEmpty().also {
require(!strict || it) { "Expected $expecting got empty" }
}
}?.let {
removeAt(size - 1).takeIf { actual -> (expecting == null || actual == expecting).also {
require(!strict || it) { "Expected $expecting got $actual" }
} }
} ?: expecting
protected fun MutableList<Node.Type.Value>?.pop(expecting: Node.Type.Value? = null) =
listOf(StackChange(popType(expecting), true))
protected fun MutableList<Node.Type.Value>?.popI32() = pop(Node.Type.Value.I32)
protected fun MutableList<Node.Type.Value>?.popI64() = pop(Node.Type.Value.I64)
protected fun MutableList<Node.Type.Value>?.popF32() = pop(Node.Type.Value.F32)
protected fun MutableList<Node.Type.Value>?.popF64() = pop(Node.Type.Value.F64)
protected fun MutableList<Node.Type.Value>?.push(type: Node.Type.Value? = null) =
listOf(StackChange(type, false)).also { if (this != null && type != null) add(type) }
protected fun MutableList<Node.Type.Value>?.pushI32() = push(Node.Type.Value.I32)
protected fun MutableList<Node.Type.Value>?.pushI64() = push(Node.Type.Value.I64)
protected fun MutableList<Node.Type.Value>?.pushF32() = push(Node.Type.Value.F32)
protected fun MutableList<Node.Type.Value>?.pushF64() = push(Node.Type.Value.F64)
data class InsnApply(
val insn: Node.Instr,
val stackAtBeginning: List<Node.Type.Value>?,
val stackChanges: List<StackChange>,
val unreachableUntilEndCount: Int
)
protected interface InsnApplyResponse
data class StackChange(
val type: Node.Type.Value?,
val pop: Boolean
) : InsnApplyResponse
data class Unreachable(
val untilEndCount: Int
) : InsnApplyResponse
class CachedModule(val mod: Node.Module) {
val importFuncs by lazy { mod.imports.mapNotNull { it.kind as? Node.Import.Kind.Func } }
val importGlobals by lazy { mod.imports.mapNotNull { it.kind as? Node.Import.Kind.Global } }
}
companion object {
fun walkStrict(mod: Node.Module, func: Node.Func, afterInsn: ((Stack, Node.Instr) -> Unit)? = null) =
func.instructions.fold(Stack(
mod = CachedModule(mod),
func = func,
current = emptyList(),
strict = true
)) { stack, insn -> stack.next(insn).also { afterInsn?.invoke(it, insn) } }.also { stack ->
// We expect to be in an unreachable state at the end or have the single return value on the stack
if (stack.unreachableUntilNextEndCount == 0) {
val expectedStack = (func.type.ret?.let { listOf(it) } ?: emptyList())
require(expectedStack == stack.current) {
"Expected end to be $expectedStack, got ${stack.current}"
}
}
}
fun stackChanges(v: Node.Instr, callFuncType: Node.Type.Func? = null) =
Stack().next(v, callFuncType).insnApplies.last().stackChanges
fun stackChanges(mod: CachedModule, func: Node.Func, v: Node.Instr) =
Stack(mod, func).next(v).insnApplies.last().stackChanges
fun stackDiff(v: Node.Instr, callFuncType: Node.Type.Func? = null) =
stackChanges(v, callFuncType).sumBy { if (it.pop) -1 else 1 }
fun stackDiff(mod: CachedModule, func: Node.Func, v: Node.Instr) =
stackChanges(mod, func, v).sumBy { if (it.pop) -1 else 1 }
}
}

View File

@ -0,0 +1,183 @@
package asmble.ast.opt
import asmble.ast.Node
import asmble.ast.Stack
// This is a naive implementation that just grabs adjacent sets of restricted insns and breaks the one that will save
// the most instructions off into its own function.
open class SplitLargeFunc(
val minSetLength: Int = 5,
val maxSetLength: Int = 40,
val maxParamCount: Int = 30
) {
// Null if no replacement. Second value is number of instructions saved. fnIndex must map to actual func,
// not imported one.
fun apply(mod: Node.Module, fnIndex: Int): Pair<Node.Module, Int>? {
// Get the func
val importFuncCount = mod.imports.count { it.kind is Node.Import.Kind.Func }
val actualFnIndex = fnIndex - importFuncCount
val func = mod.funcs.getOrElse(actualFnIndex) {
error("Unable to find non-import func at $fnIndex (actual $actualFnIndex)")
}
// Just take the best pattern and apply it
val newFuncIndex = importFuncCount + mod.funcs.size
return commonPatterns(mod, func).firstOrNull()?.let { pattern ->
// Name it as <funcname>$splitN (n is num just to disambiguate) if names are part of the mod
val newName = mod.names?.funcNames?.get(fnIndex)?.let {
"$it\$split".let { it + mod.names.funcNames.count { (_, v) -> v.startsWith(it) } }
}
// Go over every replacement in reverse, changing the instructions to our new set
val newInsns = pattern.replacements.foldRight(func.instructions) { repl, insns ->
insns.take(repl.range.start) +
repl.preCallConsts +
Node.Instr.Call(newFuncIndex) +
insns.drop(repl.range.endInclusive + 1)
}
// Return the module w/ the new function, it's new name, and the insns saved
mod.copy(
funcs = mod.funcs.toMutableList().also {
it[actualFnIndex] = func.copy(instructions = newInsns)
} + pattern.newFunc,
names = mod.names?.copy(funcNames = mod.names.funcNames.toMutableMap().also {
it[newFuncIndex] = newName!!
})
) to pattern.insnsSaved
}
}
// Results are by most insns saved. There can be overlap across patterns but never within a single pattern.
fun commonPatterns(mod: Node.Module, fn: Node.Func): List<CommonPattern> {
// Walk the stack for validation needs
val stack = Stack.walkStrict(mod, fn)
// Let's grab sets of insns that qualify. In this naive impl, in order to qualify the insn set needs to
// only have a certain set of insns that can be broken off. It can also only change the stack by 0 or 1
// value while never dipping below the starting stack. We also store the index they started at.
var insnSets = emptyList<InsnSet>()
// Pair in fold keyed by insn index
fn.instructions.foldIndexed(null as List<Pair<Int, Node.Instr>>?) { index, lastInsns, insn ->
if (!insn.canBeMoved) null else (lastInsns ?: emptyList()).plus(index to insn).also { fullNewInsnSet ->
// Get all final instructions between min and max size and with allowed param count (i.e. const count)
val trailingInsnSet = fullNewInsnSet.takeLast(maxSetLength)
// Get all instructions between the min and max
insnSets += (minSetLength..maxSetLength).
asSequence().
flatMap { trailingInsnSet.asSequence().windowed(it) }.
filter { it.count { it.second is Node.Instr.Args.Const<*> } <= maxParamCount }.
mapNotNull { newIndexedInsnSet ->
// Before adding, make sure it qualifies with the stack
InsnSet(
startIndex = newIndexedInsnSet.first().first,
insns = newIndexedInsnSet.map { it.second },
valueAddedToStack = null
).withStackValueIfValid(stack)
}
}
}
// Sort the insn sets by the ones with the most insns
insnSets = insnSets.sortedByDescending { it.insns.size }
// Now let's create replacements for each, keyed by the extracted func
val patterns = insnSets.fold(emptyMap<Node.Func, List<Replacement>>()) { map, insnSet ->
insnSet.extractCommonFunc().let { (func, replacement) ->
val existingReplacements = map.getOrDefault(func, emptyList())
// Ignore if there is any overlap
if (existingReplacements.any(replacement::overlaps)) map
else map + (func to existingReplacements.plus(replacement))
}
}
// Now sort the patterns by most insns saved and return
return patterns.map { (k, v) ->
CommonPattern(k, v.sortedBy { it.range.first })
}.sortedByDescending { it.insnsSaved }
}
val Node.Instr.canBeMoved get() =
// No blocks
this !is Node.Instr.Block && this !is Node.Instr.Loop && this !is Node.Instr.If &&
this !is Node.Instr.Else && this !is Node.Instr.End &&
// No breaks
this !is Node.Instr.Br && this !is Node.Instr.BrIf && this !is Node.Instr.BrTable &&
// No return
this !is Node.Instr.Return &&
// No local access
this !is Node.Instr.GetLocal && this !is Node.Instr.SetLocal && this !is Node.Instr.TeeLocal
fun InsnSet.withStackValueIfValid(stack: Stack): InsnSet? {
// This makes sure that the stack only changes by at most one item and never dips below its starting val.
// If it is invalid, null is returned. If it qualifies and does change 1 value, it is set.
// First, make sure the stack after the last insn is the same as the first or the same + 1 val
val startingStack = stack.insnApplies[startIndex].stackAtBeginning!!
val endingStack = stack.insnApplies.getOrNull(startIndex + insns.size)?.stackAtBeginning ?: stack.current!!
if (endingStack.size != startingStack.size && endingStack.size != startingStack.size + 1) return null
if (endingStack.take(startingStack.size) != startingStack) return null
// Now, walk the insns and make sure they never pop below the start
var stackCounter = 0
stack.insnApplies.subList(startIndex, startIndex + insns.size).forEach {
it.stackChanges.forEach {
stackCounter += if (it.pop) -1 else 1
if (stackCounter < 0) return null
}
}
// We're good, now only if the ending stack is one over the start do we have a ret val
return copy(
valueAddedToStack = endingStack.lastOrNull()?.takeIf { endingStack.size == startingStack.size + 1 }
)
}
fun InsnSet.extractCommonFunc() =
// This extracts a function with constants changed to parameters
insns.fold(Pair(
Node.Func(Node.Type.Func(params = emptyList(), ret = valueAddedToStack), emptyList(), emptyList()),
Replacement(range = startIndex until startIndex + insns.size, preCallConsts = emptyList()))
) { (func, repl), insn ->
if (insn !is Node.Instr.Args.Const<*>) func.copy(instructions = func.instructions + insn) to repl
else func.copy(
type = func.type.copy(params = func.type.params + insn.constType),
instructions = func.instructions + Node.Instr.GetLocal(func.type.params.size)
) to repl.copy(preCallConsts = repl.preCallConsts + insn)
}
protected val Node.Instr.Args.Const<*>.constType get() = when (this) {
is Node.Instr.I32Const -> Node.Type.Value.I32
is Node.Instr.I64Const -> Node.Type.Value.I64
is Node.Instr.F32Const -> Node.Type.Value.F32
is Node.Instr.F64Const -> Node.Type.Value.F64
else -> error("unreachable")
}
data class InsnSet(
val startIndex: Int,
val insns: List<Node.Instr>,
val valueAddedToStack: Node.Type.Value?
)
data class Replacement(
val range: IntRange,
val preCallConsts: List<Node.Instr>
) {
// Subtract one because there is a call after this
val insnsSaved get() = (range.last + 1) - range.first - 1 - preCallConsts.size
fun overlaps(o: Replacement) = range.contains(o.range.first) || range.contains(o.range.last) ||
o.range.contains(range.first) || o.range.contains(range.last)
}
data class CommonPattern(
val newFunc: Node.Func,
// In order by earliest replacement first
val replacements: List<Replacement>
) {
// Replacement pieces saved (with one added for the invocation) less new func instructions
val insnsSaved get() = replacements.sumBy { it.insnsSaved } - newFunc.instructions.size
}
companion object : SplitLargeFunc()
}

View File

@ -3,7 +3,7 @@ package asmble.cli
import asmble.util.Logger
import kotlin.system.exitProcess
val commands = listOf(Compile, Help, Invoke, Link, Run, Translate)
val commands = listOf(Compile, Help, Invoke, Link, Run, SplitFunc, Translate)
/**
* Entry point of command line interface.

View File

@ -0,0 +1,146 @@
package asmble.cli
import asmble.ast.Node
import asmble.ast.Script
import asmble.ast.opt.SplitLargeFunc
open class SplitFunc : Command<SplitFunc.Args>() {
override val name = "split-func"
override val desc = "Split a WebAssembly function into two"
override fun args(bld: Command.ArgsBuilder) = Args(
inFile = bld.arg(
name = "inFile",
desc = "The wast or wasm WebAssembly file name. Can be '--' to read from stdin."
),
funcName = bld.arg(
name = "funcName",
desc = "The name (or '#' + function space index) of the function to split"
),
inFormat = bld.arg(
name = "inFormat",
opt = "in",
desc = "Either 'wast' or 'wasm' to describe format.",
default = "<use file extension>",
lowPriority = true
),
outFile = bld.arg(
name = "outFile",
opt = "outFile",
desc = "The wast or wasm WebAssembly file name. Can be '--' to write to stdout.",
default = "<inFileSansExt.split.wasm or stdout>",
lowPriority = true
),
outFormat = bld.arg(
name = "outFormat",
opt = "out",
desc = "Either 'wast' or 'wasm' to describe format.",
default = "<use file extension or wast for stdout>",
lowPriority = true
),
compact = bld.flag(
opt = "compact",
desc = "If set for wast out format, will be compacted.",
lowPriority = true
),
minInsnSetLength = bld.arg(
name = "minInsnSetLength",
opt = "minLen",
desc = "The minimum number of instructions allowed for the split off function.",
default = "5",
lowPriority = true
).toInt(),
maxInsnSetLength = bld.arg(
name = "maxInsnSetLength",
opt = "maxLen",
desc = "The maximum number of instructions allowed for the split off function.",
default = "40",
lowPriority = true
).toInt(),
maxNewFuncParamCount = bld.arg(
name = "maxNewFuncParamCount",
opt = "maxParams",
desc = "The maximum number of params allowed for the split off function.",
default = "30",
lowPriority = true
).toInt(),
attempts = bld.arg(
name = "attempts",
opt = "attempts",
desc = "The number of attempts to perform.",
default = "1",
lowPriority = true
).toInt()
).also { bld.done() }
override fun run(args: Args) {
// Load the mod
val translate = Translate().also { it.logger = logger }
val inFormat =
if (args.inFormat != "<use file extension>") args.inFormat
else args.inFile.substringAfterLast('.', "<unknown>")
val script = translate.inToAst(args.inFile, inFormat)
var mod = (script.commands.firstOrNull() as? Script.Cmd.Module)?.module ?: error("Only a single module allowed")
// Do attempts
val splitter = SplitLargeFunc(
minSetLength = args.minInsnSetLength,
maxSetLength = args.maxInsnSetLength,
maxParamCount = args.maxNewFuncParamCount
)
for (attempt in 0 until args.attempts) {
// Find the function
var index = mod.names?.funcNames?.toList()?.find { it.second == args.funcName }?.first
if (index == null && args.funcName.startsWith('#')) index = args.funcName.drop(1).toInt()
val origFunc = index?.let {
mod.funcs.getOrNull(it - mod.imports.count { it.kind is Node.Import.Kind.Func })
} ?: error("Unable to find func")
// Split it
val results = splitter.apply(mod, index)
if (results == null) {
logger.warn { "No instructions after attempt $attempt" }
break
}
val (splitMod, insnsSaved) = results
val newFunc = splitMod.funcs[index - mod.imports.count { it.kind is Node.Import.Kind.Func }]
val splitFunc = splitMod.funcs.last()
logger.warn {
"Split complete, from func with ${origFunc.instructions.size} insns to a func " +
"with ${newFunc.instructions.size} insns + delegated func " +
"with ${splitFunc.instructions.size} insns and ${splitFunc.type.params.size} params, " +
"saved $insnsSaved insns"
}
mod = splitMod
}
// Write it
val outFile = when {
args.outFile != "<inFileSansExt.split.wasm or stdout>" -> args.outFile
args.inFile == "--" -> "--"
else -> args.inFile.replaceAfterLast('.', "split." + args.inFile.substringAfterLast('.'))
}
val outFormat = when {
args.outFormat != "<use file extension or wast for stdout>" -> args.outFormat
outFile == "--" -> "wast"
else -> outFile.substringAfterLast('.', "<unknown>")
}
translate.astToOut(outFile, outFormat, args.compact,
Script(listOf(Script.Cmd.Module(mod, mod.names?.moduleName))))
}
data class Args(
val inFile: String,
val inFormat: String,
val funcName: String,
val outFile: String,
val outFormat: String,
val compact: Boolean,
val minInsnSetLength: Int,
val maxInsnSetLength: Int,
val maxNewFuncParamCount: Int,
val attempts: Int
)
companion object : SplitFunc()
}

View File

@ -52,24 +52,7 @@ open class Translate : Command<Translate.Args>() {
if (args.outFormat != "<use file extension or wast for stdout>") args.outFormat
else if (args.outFile == "--") "wast"
else args.outFile.substringAfterLast('.', "<unknown>")
val outStream =
if (args.outFile == "--") System.out
else FileOutputStream(args.outFile)
outStream.use { outStream ->
when (outFormat) {
"wast" -> {
val sexprToStr = if (args.compact) SExprToStr.Compact else SExprToStr
val sexprs = AstToSExpr.fromScript(script)
outStream.write(sexprToStr.fromSExpr(*sexprs.toTypedArray()).toByteArray())
}
"wasm" -> {
val mod = (script.commands.firstOrNull() as? Script.Cmd.Module)?.module ?:
error("Output to WASM requires input be just a single module")
AstToBinary.fromModule(ByteWriter.OutputStream(outStream), mod)
}
else -> error("Unknown out format '$outFormat'")
}
}
astToOut(args.outFile, outFormat, args.compact, script)
}
fun inToAst(inFile: String, inFormat: String): Script {
@ -93,6 +76,27 @@ open class Translate : Command<Translate.Args>() {
}
}
fun astToOut(outFile: String, outFormat: String, compact: Boolean, script: Script) {
val outStream =
if (outFile == "--") System.out
else FileOutputStream(outFile)
outStream.use { outStream ->
when (outFormat) {
"wast" -> {
val sexprToStr = if (compact) SExprToStr.Compact else SExprToStr
val sexprs = AstToSExpr.fromScript(script)
outStream.write(sexprToStr.fromSExpr(*sexprs.toTypedArray()).toByteArray())
}
"wasm" -> {
val mod = (script.commands.firstOrNull() as? Script.Cmd.Module)?.module ?:
error("Output to WASM requires input be just a single module")
AstToBinary.fromModule(ByteWriter.OutputStream(outStream), mod)
}
else -> error("Unknown out format '$outFormat'")
}
}
}
data class Args(
val inFile: String,
val inFormat: String,

View File

@ -0,0 +1,33 @@
package asmble.ast
import asmble.SpecTestUnit
import asmble.TestBase
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
@RunWith(Parameterized::class)
class StackTest(val unit: SpecTestUnit) : TestBase() {
@Test
fun testStack() {
// If it's not a module expecting an error, we'll try to walk the stack on each function
unit.script.commands.mapNotNull { it as? Script.Cmd.Module }.forEach { mod ->
mod.module.funcs.filter { it.instructions.isNotEmpty() }.forEach { func ->
debug { "Func: ${func.type}" }
var indexCount = 0
Stack.walkStrict(mod.module, func) { stack, insn ->
debug { " After $insn (next: ${func.instructions.getOrNull(++indexCount)}, " +
"unreach depth: ${stack.unreachableUntilNextEndCount})" }
debug { " " + stack.current }
}
}
}
}
companion object {
// Only tests that shouldn't fail
@JvmStatic @Parameterized.Parameters(name = "{0}")
fun data() = SpecTestUnit.allUnits.filterNot { it.shouldFail }//.filter { it.name == "loop" }
}
}

View File

@ -0,0 +1,89 @@
package asmble.ast.opt
import asmble.TestBase
import asmble.ast.Node
import asmble.compile.jvm.AstToAsm
import asmble.compile.jvm.ClsContext
import asmble.run.jvm.ScriptContext
import org.junit.Test
import java.nio.ByteBuffer
import java.util.*
import kotlin.test.assertEquals
class SplitLargeFuncTest : TestBase() {
@Test
fun testSplitLargeFunc() {
// We're going to make a large function that does some addition and then stores in mem
val ctx = ClsContext(
packageName = "test",
className = "Temp" + UUID.randomUUID().toString().replace("-", ""),
logger = logger,
mod = Node.Module(
memories = listOf(Node.Type.Memory(Node.ResizableLimits(initial = 2, maximum = 2))),
funcs = listOf(Node.Func(
type = Node.Type.Func(params = emptyList(), ret = null),
locals = emptyList(),
instructions = (0 until 501).flatMap {
listOf<Node.Instr>(
Node.Instr.I32Const(it * 4),
// Let's to i * (i = 1)
Node.Instr.I32Const(it),
Node.Instr.I32Const(it - 1),
Node.Instr.I32Mul,
Node.Instr.I32Store(0, 0)
)
}
)),
names = Node.NameSection(
moduleName = null,
funcNames = mapOf(0 to "someFunc"),
localNames = emptyMap()
),
exports = listOf(
Node.Export("memory", Node.ExternalKind.MEMORY, 0),
Node.Export("someFunc", Node.ExternalKind.FUNCTION, 0)
)
)
)
// Compile it
AstToAsm.fromModule(ctx)
val cls = ScriptContext.SimpleClassLoader(javaClass.classLoader, logger).fromBuiltContext(ctx)
val inst = cls.newInstance()
// Run someFunc
cls.getMethod("someFunc").invoke(inst)
// Get the memory out
val mem = cls.getMethod("getMemory").invoke(inst) as ByteBuffer
// Read out the mem values
(0 until 501).forEach { assertEquals(it * (it - 1), mem.getInt(it * 4)) }
// Now split it
val (splitMod, insnsSaved) = SplitLargeFunc.apply(ctx.mod, 0) ?: error("Nothing could be split")
// Count insns and confirm it is as expected
val origInsnCount = ctx.mod.funcs.sumBy { it.instructions.size }
val newInsnCount = splitMod.funcs.sumBy { it.instructions.size }
assertEquals(origInsnCount - newInsnCount, insnsSaved)
// Compile it
val splitCtx = ClsContext(
packageName = "test",
className = "Temp" + UUID.randomUUID().toString().replace("-", ""),
logger = logger,
mod = splitMod
)
AstToAsm.fromModule(splitCtx)
val splitCls = ScriptContext.SimpleClassLoader(javaClass.classLoader, logger).fromBuiltContext(splitCtx)
val splitInst = splitCls.newInstance()
// Run someFunc
splitCls.getMethod("someFunc").invoke(splitInst)
// Get the memory out and compare it
val splitMem = splitCls.getMethod("getMemory").invoke(splitInst) as ByteBuffer
assertEquals(mem, splitMem)
// Dump some info
logger.debug {
val orig = ctx.mod.funcs.first()
val (new, split) = splitMod.funcs.let { it.first() to it.last() }
"Split complete, from single func with ${orig.instructions.size} insns to func " +
"with ${new.instructions.size} insns + delegated func " +
"with ${split.instructions.size} insns and ${split.type.params.size} params"
}
}
}

View File

@ -1,6 +1,7 @@
package asmble.io
import asmble.SpecTestUnit
import asmble.TestBase
import asmble.ast.Node
import asmble.ast.Script
import asmble.util.Logger
@ -13,12 +14,10 @@ import java.io.ByteArrayOutputStream
import kotlin.test.assertEquals
@RunWith(Parameterized::class)
class IoTest(val unit: SpecTestUnit) : Logger by Logger.Print(Logger.Level.INFO) {
class IoTest(val unit: SpecTestUnit) : TestBase() {
@Test
fun testIo() {
// Ignore things that are supposed to fail
if (unit.shouldFail) return
// Go from the AST to binary then back to AST then back to binary and confirm values are as expected
val ast1 = unit.script.commands.mapNotNull { (it as? Script.Cmd.Module)?.module?.also {
trace { "AST from script:\n" + SExprToStr.fromSExpr(AstToSExpr.fromModule(it)) }
@ -46,7 +45,8 @@ class IoTest(val unit: SpecTestUnit) : Logger by Logger.Print(Logger.Level.INFO)
}
companion object {
// Only tests that shouldn't fail
@JvmStatic @Parameterized.Parameters(name = "{0}")
fun data() = SpecTestUnit.allUnits
fun data() = SpecTestUnit.allUnits.filterNot { it.shouldFail }
}
}

View File

@ -0,0 +1,10 @@
package main
import (
"fmt"
"os"
)
func main() {
fmt.Printf("Args: %v", os.Args)
}

View File

@ -1,6 +1,8 @@
rootProject.name = 'asmble'
include 'annotations',
'compiler',
// 'examples:rust-regex', // todo will be enabled when the problem with string max size will be solved
'examples:c-simple',
'examples:go-simple',
'examples:rust-regex',
'examples:rust-simple',
'examples:rust-string'