mirror of
https://github.com/fluencelabs/asmble
synced 2025-07-05 01:11:37 +00:00
merge original-master to master
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@ -16,6 +16,8 @@
|
|||||||
/annotations/out
|
/annotations/out
|
||||||
/examples/c-simple/bin
|
/examples/c-simple/bin
|
||||||
/examples/c-simple/build
|
/examples/c-simple/build
|
||||||
|
/examples/go-simple/bin
|
||||||
|
/examples/go-simple/build
|
||||||
/examples/rust-simple/Cargo.lock
|
/examples/rust-simple/Cargo.lock
|
||||||
/examples/rust-simple/bin
|
/examples/rust-simple/bin
|
||||||
/examples/rust-simple/build
|
/examples/rust-simple/build
|
||||||
|
42
build.gradle
42
build.gradle
@ -67,6 +67,31 @@ project(':examples') {
|
|||||||
compileOnly project(':compiler')
|
compileOnly project(':compiler')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Go example helpers
|
||||||
|
|
||||||
|
task goToWasm {
|
||||||
|
doFirst {
|
||||||
|
mkdir 'build'
|
||||||
|
exec {
|
||||||
|
def goFileName = fileTree(dir: '.', includes: ['*.go']).files.iterator().next()
|
||||||
|
environment 'GOOS': 'js', 'GOARCH': 'wasm'
|
||||||
|
commandLine 'go', 'build', '-o', 'build/lib.wasm', goFileName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
task compileGoWasm(type: JavaExec) {
|
||||||
|
dependsOn goToWasm
|
||||||
|
classpath configurations.compileClasspath
|
||||||
|
main = 'asmble.cli.MainKt'
|
||||||
|
doFirst {
|
||||||
|
// args 'help', 'compile'
|
||||||
|
def outFile = 'build/wasm-classes/' + wasmCompiledClassName.replace('.', '/') + '.class'
|
||||||
|
file(outFile).parentFile.mkdirs()
|
||||||
|
args 'compile', 'build/lib.wasm', wasmCompiledClassName, '-out', outFile, '-log', 'debug'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Rust example helpers
|
// Rust example helpers
|
||||||
|
|
||||||
ext.rustBuildRelease = true
|
ext.rustBuildRelease = true
|
||||||
@ -116,9 +141,22 @@ project(':examples') {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
project(':examples:go-simple') {
|
||||||
|
apply plugin: 'application'
|
||||||
|
ext.wasmCompiledClassName = 'asmble.generated.GoSimple'
|
||||||
|
dependencies {
|
||||||
|
compile files('build/wasm-classes')
|
||||||
|
}
|
||||||
|
compileJava {
|
||||||
|
dependsOn compileGoWasm
|
||||||
|
}
|
||||||
|
mainClassName = 'asmble.examples.gosimple.Main'
|
||||||
|
}
|
||||||
|
|
||||||
// todo temporary disable Rust regex, because some strings in wasm code exceed the size in 65353 bytes.
|
// todo temporary disable Rust regex, because some strings in wasm code exceed the size in 65353 bytes.
|
||||||
|
|
||||||
//project(':examples:rust-regex') {
|
// project(':examples:rust-regex') {
|
||||||
// apply plugin: 'application'
|
// apply plugin: 'application'
|
||||||
// apply plugin: 'me.champeau.gradle.jmh'
|
// apply plugin: 'me.champeau.gradle.jmh'
|
||||||
// ext.wasmCompiledClassName = 'asmble.generated.RustRegex'
|
// ext.wasmCompiledClassName = 'asmble.generated.RustRegex'
|
||||||
@ -139,7 +177,7 @@ project(':examples') {
|
|||||||
// warmupIterations = 5
|
// warmupIterations = 5
|
||||||
// fork = 3
|
// fork = 3
|
||||||
// }
|
// }
|
||||||
//}
|
// }
|
||||||
|
|
||||||
project(':examples:rust-simple') {
|
project(':examples:rust-simple') {
|
||||||
apply plugin: 'application'
|
apply plugin: 'application'
|
||||||
|
250
compiler/src/main/kotlin/asmble/ast/Stack.kt
Normal file
250
compiler/src/main/kotlin/asmble/ast/Stack.kt
Normal file
@ -0,0 +1,250 @@
|
|||||||
|
package asmble.ast
|
||||||
|
|
||||||
|
// This is a utility for walking the stack. It can do validation or just walk naively.
|
||||||
|
data class Stack(
|
||||||
|
// If some of these values below are null, the pops/pushes may appear "unknown"
|
||||||
|
val mod: CachedModule? = null,
|
||||||
|
val func: Node.Func? = null,
|
||||||
|
// Null if not tracking the current stack and all pops succeed
|
||||||
|
val current: List<Node.Type.Value>? = null,
|
||||||
|
val insnApplies: List<InsnApply> = emptyList(),
|
||||||
|
val strict: Boolean = false,
|
||||||
|
val unreachableUntilNextEndCount: Int = 0
|
||||||
|
) {
|
||||||
|
|
||||||
|
fun next(v: Node.Instr, callFuncTypeOverride: Node.Type.Func? = null) = insnApply(v) {
|
||||||
|
// If we're unreachable, and not an end, we skip and move on
|
||||||
|
if (unreachableUntilNextEndCount > 0 && v !is Node.Instr.End) {
|
||||||
|
// If it's a block, we increase it because we'll see another end
|
||||||
|
return@insnApply if (v is Node.Instr.Args.Type) unreachable(unreachableUntilNextEndCount + 1) else nop()
|
||||||
|
}
|
||||||
|
when (v) {
|
||||||
|
is Node.Instr.Nop, is Node.Instr.Block, is Node.Instr.Loop -> nop()
|
||||||
|
is Node.Instr.If, is Node.Instr.BrIf -> popI32()
|
||||||
|
is Node.Instr.Return -> (func?.type?.ret?.let { pop(it) } ?: nop()) + unreachable(1)
|
||||||
|
is Node.Instr.Unreachable -> unreachable(1)
|
||||||
|
is Node.Instr.End, is Node.Instr.Else -> {
|
||||||
|
// Put back what was before the last block and add the block's type
|
||||||
|
// Go backwards to find the starting block
|
||||||
|
var currDepth = 0
|
||||||
|
val found = insnApplies.findLast {
|
||||||
|
when (it.insn) {
|
||||||
|
is Node.Instr.End -> { currDepth++; false }
|
||||||
|
is Node.Instr.Args.Type -> if (currDepth > 0) { currDepth--; false } else true
|
||||||
|
else -> false
|
||||||
|
}
|
||||||
|
}?.takeIf {
|
||||||
|
// When it's else, needs to be if
|
||||||
|
v !is Node.Instr.Else || it.insn is Node.Instr.If
|
||||||
|
}
|
||||||
|
val changes = when {
|
||||||
|
found != null && found.insn is Node.Instr.Args.Type &&
|
||||||
|
found.stackAtBeginning != null && this != null -> {
|
||||||
|
// Pop everything from before the block's start, then push if necessary...
|
||||||
|
// The If block includes an int at the beginning we must not include when subtracting
|
||||||
|
var preBlockStackSize = found.stackAtBeginning.size
|
||||||
|
if (found.insn is Node.Instr.If) preBlockStackSize--
|
||||||
|
val popped =
|
||||||
|
if (unreachableUntilNextEndCount > 1) nop()
|
||||||
|
else (0 until (size - preBlockStackSize)).flatMap { pop() }
|
||||||
|
// Only push if this is not an else
|
||||||
|
val pushed =
|
||||||
|
if (unreachableUntilNextEndCount > 1 || v is Node.Instr.Else) nop()
|
||||||
|
else (found.insn.type?.let { push(it) } ?: nop())
|
||||||
|
popped + pushed
|
||||||
|
}
|
||||||
|
strict -> error("Unable to find starting block for end")
|
||||||
|
else -> nop()
|
||||||
|
}
|
||||||
|
if (unreachableUntilNextEndCount > 0) changes + unreachable(unreachableUntilNextEndCount - 1)
|
||||||
|
else changes
|
||||||
|
}
|
||||||
|
is Node.Instr.Br -> unreachable(v.relativeDepth + 1)
|
||||||
|
is Node.Instr.BrTable -> popI32() + unreachable(1)
|
||||||
|
is Node.Instr.Call -> (callFuncTypeOverride ?: func(v.index)).let {
|
||||||
|
if (it == null) error("Call func type missing")
|
||||||
|
it.params.reversed().flatMap { pop(it) } + (it.ret?.let { push(it) } ?: nop())
|
||||||
|
}
|
||||||
|
is Node.Instr.CallIndirect -> (callFuncTypeOverride ?: mod?.mod?.types?.getOrNull(v.index)).let {
|
||||||
|
if (it == null) error("Call func type missing")
|
||||||
|
// We add one for the table index
|
||||||
|
popI32() + it.params.reversed().flatMap { pop(it) } + (it.ret?.let { push(it) } ?: nop())
|
||||||
|
}
|
||||||
|
is Node.Instr.Drop -> pop()
|
||||||
|
is Node.Instr.Select -> popI32() + pop().let { it + pop(it.first().type) + push(it.first().type) }
|
||||||
|
is Node.Instr.GetLocal -> push(local(v.index))
|
||||||
|
is Node.Instr.SetLocal -> pop(local(v.index))
|
||||||
|
is Node.Instr.TeeLocal -> local(v.index).let { pop(it) + push(it) }
|
||||||
|
is Node.Instr.GetGlobal -> push(global(v.index))
|
||||||
|
is Node.Instr.SetGlobal -> pop(global(v.index))
|
||||||
|
is Node.Instr.I32Load, is Node.Instr.I32Load8S, is Node.Instr.I32Load8U,
|
||||||
|
is Node.Instr.I32Load16U, is Node.Instr.I32Load16S -> popI32() + pushI32()
|
||||||
|
is Node.Instr.I64Load, is Node.Instr.I64Load8S, is Node.Instr.I64Load8U, is Node.Instr.I64Load16U,
|
||||||
|
is Node.Instr.I64Load16S, is Node.Instr.I64Load32S, is Node.Instr.I64Load32U -> popI32() + pushI64()
|
||||||
|
is Node.Instr.F32Load -> popI32() + pushF32()
|
||||||
|
is Node.Instr.F64Load -> popI32() + pushF64()
|
||||||
|
is Node.Instr.I32Store, is Node.Instr.I32Store8, is Node.Instr.I32Store16 -> popI32() + popI32()
|
||||||
|
is Node.Instr.I64Store, is Node.Instr.I64Store8,
|
||||||
|
is Node.Instr.I64Store16, is Node.Instr.I64Store32 -> popI64() + popI32()
|
||||||
|
is Node.Instr.F32Store -> popF32() + popI32()
|
||||||
|
is Node.Instr.F64Store -> popF64() + popI32()
|
||||||
|
is Node.Instr.MemorySize -> pushI32()
|
||||||
|
is Node.Instr.MemoryGrow -> popI32() + pushI32()
|
||||||
|
is Node.Instr.I32Const -> pushI32()
|
||||||
|
is Node.Instr.I64Const -> pushI64()
|
||||||
|
is Node.Instr.F32Const -> pushF32()
|
||||||
|
is Node.Instr.F64Const -> pushF64()
|
||||||
|
is Node.Instr.I32Add, is Node.Instr.I32Sub, is Node.Instr.I32Mul, is Node.Instr.I32DivS,
|
||||||
|
is Node.Instr.I32DivU, is Node.Instr.I32RemS, is Node.Instr.I32RemU, is Node.Instr.I32And,
|
||||||
|
is Node.Instr.I32Or, is Node.Instr.I32Xor, is Node.Instr.I32Shl, is Node.Instr.I32ShrS,
|
||||||
|
is Node.Instr.I32ShrU, is Node.Instr.I32Rotl, is Node.Instr.I32Rotr, is Node.Instr.I32Eq,
|
||||||
|
is Node.Instr.I32Ne, is Node.Instr.I32LtS, is Node.Instr.I32LeS, is Node.Instr.I32LtU,
|
||||||
|
is Node.Instr.I32LeU, is Node.Instr.I32GtS, is Node.Instr.I32GeS, is Node.Instr.I32GtU,
|
||||||
|
is Node.Instr.I32GeU -> popI32() + popI32() + pushI32()
|
||||||
|
is Node.Instr.I32Clz, is Node.Instr.I32Ctz, is Node.Instr.I32Popcnt,
|
||||||
|
is Node.Instr.I32Eqz -> popI32() + pushI32()
|
||||||
|
is Node.Instr.I64Add, is Node.Instr.I64Sub, is Node.Instr.I64Mul, is Node.Instr.I64DivS,
|
||||||
|
is Node.Instr.I64DivU, is Node.Instr.I64RemS, is Node.Instr.I64RemU, is Node.Instr.I64And,
|
||||||
|
is Node.Instr.I64Or, is Node.Instr.I64Xor, is Node.Instr.I64Shl, is Node.Instr.I64ShrS,
|
||||||
|
is Node.Instr.I64ShrU, is Node.Instr.I64Rotl, is Node.Instr.I64Rotr -> popI64() + popI64() + pushI64()
|
||||||
|
is Node.Instr.I64Eq, is Node.Instr.I64Ne, is Node.Instr.I64LtS, is Node.Instr.I64LeS,
|
||||||
|
is Node.Instr.I64LtU, is Node.Instr.I64LeU, is Node.Instr.I64GtS,
|
||||||
|
is Node.Instr.I64GeS, is Node.Instr.I64GtU, is Node.Instr.I64GeU -> popI64() + popI64() + pushI32()
|
||||||
|
is Node.Instr.I64Clz, is Node.Instr.I64Ctz, is Node.Instr.I64Popcnt -> popI64() + pushI64()
|
||||||
|
is Node.Instr.I64Eqz -> popI64() + pushI32()
|
||||||
|
is Node.Instr.F32Add, is Node.Instr.F32Sub, is Node.Instr.F32Mul, is Node.Instr.F32Div,
|
||||||
|
is Node.Instr.F32Min, is Node.Instr.F32Max, is Node.Instr.F32CopySign -> popF32() + popF32() + pushF32()
|
||||||
|
is Node.Instr.F32Eq, is Node.Instr.F32Ne, is Node.Instr.F32Lt, is Node.Instr.F32Le,
|
||||||
|
is Node.Instr.F32Gt, is Node.Instr.F32Ge -> popF32() + popF32() + pushI32()
|
||||||
|
is Node.Instr.F32Abs, is Node.Instr.F32Neg, is Node.Instr.F32Ceil, is Node.Instr.F32Floor,
|
||||||
|
is Node.Instr.F32Trunc, is Node.Instr.F32Nearest, is Node.Instr.F32Sqrt -> popF32() + pushF32()
|
||||||
|
is Node.Instr.F64Add, is Node.Instr.F64Sub, is Node.Instr.F64Mul, is Node.Instr.F64Div,
|
||||||
|
is Node.Instr.F64Min, is Node.Instr.F64Max, is Node.Instr.F64CopySign -> popF64() + popF64() + pushF64()
|
||||||
|
is Node.Instr.F64Eq, is Node.Instr.F64Ne, is Node.Instr.F64Lt, is Node.Instr.F64Le,
|
||||||
|
is Node.Instr.F64Gt, is Node.Instr.F64Ge -> popF64() + popF64() + pushI32()
|
||||||
|
is Node.Instr.F64Abs, is Node.Instr.F64Neg, is Node.Instr.F64Ceil, is Node.Instr.F64Floor,
|
||||||
|
is Node.Instr.F64Trunc, is Node.Instr.F64Nearest, is Node.Instr.F64Sqrt -> popF64() + pushF64()
|
||||||
|
is Node.Instr.I32WrapI64 -> popI64() + pushI32()
|
||||||
|
is Node.Instr.I32TruncSF32, is Node.Instr.I32TruncUF32,
|
||||||
|
is Node.Instr.I32ReinterpretF32 -> popF32() + pushI32()
|
||||||
|
is Node.Instr.I32TruncSF64, is Node.Instr.I32TruncUF64 -> popF64() + pushI32()
|
||||||
|
is Node.Instr.I64ExtendSI32, is Node.Instr.I64ExtendUI32 -> popI32() + pushI64()
|
||||||
|
is Node.Instr.I64TruncSF32, is Node.Instr.I64TruncUF32 -> popF32() + pushI64()
|
||||||
|
is Node.Instr.I64TruncSF64, is Node.Instr.I64TruncUF64,
|
||||||
|
is Node.Instr.I64ReinterpretF64 -> popF64() + pushI64()
|
||||||
|
is Node.Instr.F32ConvertSI32, is Node.Instr.F32ConvertUI32,
|
||||||
|
is Node.Instr.F32ReinterpretI32 -> popI32() + pushF32()
|
||||||
|
is Node.Instr.F32ConvertSI64, is Node.Instr.F32ConvertUI64 -> popI64() + pushF32()
|
||||||
|
is Node.Instr.F32DemoteF64 -> popF64() + pushF32()
|
||||||
|
is Node.Instr.F64ConvertSI32, is Node.Instr.F64ConvertUI32 -> popI32() + pushF64()
|
||||||
|
is Node.Instr.F64ConvertSI64, is Node.Instr.F64ConvertUI64,
|
||||||
|
is Node.Instr.F64ReinterpretI64 -> popI64() + pushF64()
|
||||||
|
is Node.Instr.F64PromoteF32 -> popF32() + pushF64()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected fun insnApply(v: Node.Instr, fn: MutableList<Node.Type.Value>?.() -> List<InsnApplyResponse>): Stack {
|
||||||
|
val mutStack = current?.toMutableList()
|
||||||
|
val applyResp = mutStack.fn()
|
||||||
|
val newUnreachable = (applyResp.find { it is Unreachable } as? Unreachable)?.untilEndCount
|
||||||
|
return copy(
|
||||||
|
current = mutStack,
|
||||||
|
insnApplies = insnApplies + InsnApply(
|
||||||
|
insn = v,
|
||||||
|
stackAtBeginning = current,
|
||||||
|
stackChanges = applyResp.mapNotNull { it as? StackChange },
|
||||||
|
unreachableUntilEndCount = newUnreachable ?: unreachableUntilNextEndCount
|
||||||
|
),
|
||||||
|
unreachableUntilNextEndCount = newUnreachable ?: unreachableUntilNextEndCount
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
protected fun unreachable(untilEndCount: Int) = listOf(Unreachable(untilEndCount))
|
||||||
|
protected fun local(index: Int) = func?.let {
|
||||||
|
it.type.params.getOrNull(index) ?: it.locals.getOrNull(index - it.type.params.size)
|
||||||
|
}
|
||||||
|
protected fun global(index: Int) = mod?.let {
|
||||||
|
it.importGlobals.getOrNull(index)?.type?.contentType ?:
|
||||||
|
it.mod.globals.getOrNull(index - it.importGlobals.size)?.type?.contentType
|
||||||
|
}
|
||||||
|
protected fun func(index: Int) = mod?.let {
|
||||||
|
it.importFuncs.getOrNull(index)?.typeIndex?.let { i -> it.mod.types.getOrNull(i) } ?:
|
||||||
|
it.mod.funcs.getOrNull(index - it.importFuncs.size)?.type
|
||||||
|
}
|
||||||
|
|
||||||
|
protected fun nop() = emptyList<StackChange>()
|
||||||
|
protected fun MutableList<Node.Type.Value>?.popType(expecting: Node.Type.Value? = null) =
|
||||||
|
this?.takeIf {
|
||||||
|
it.isNotEmpty().also {
|
||||||
|
require(!strict || it) { "Expected $expecting got empty" }
|
||||||
|
}
|
||||||
|
}?.let {
|
||||||
|
removeAt(size - 1).takeIf { actual -> (expecting == null || actual == expecting).also {
|
||||||
|
require(!strict || it) { "Expected $expecting got $actual" }
|
||||||
|
} }
|
||||||
|
} ?: expecting
|
||||||
|
protected fun MutableList<Node.Type.Value>?.pop(expecting: Node.Type.Value? = null) =
|
||||||
|
listOf(StackChange(popType(expecting), true))
|
||||||
|
|
||||||
|
protected fun MutableList<Node.Type.Value>?.popI32() = pop(Node.Type.Value.I32)
|
||||||
|
protected fun MutableList<Node.Type.Value>?.popI64() = pop(Node.Type.Value.I64)
|
||||||
|
protected fun MutableList<Node.Type.Value>?.popF32() = pop(Node.Type.Value.F32)
|
||||||
|
protected fun MutableList<Node.Type.Value>?.popF64() = pop(Node.Type.Value.F64)
|
||||||
|
|
||||||
|
protected fun MutableList<Node.Type.Value>?.push(type: Node.Type.Value? = null) =
|
||||||
|
listOf(StackChange(type, false)).also { if (this != null && type != null) add(type) }
|
||||||
|
protected fun MutableList<Node.Type.Value>?.pushI32() = push(Node.Type.Value.I32)
|
||||||
|
protected fun MutableList<Node.Type.Value>?.pushI64() = push(Node.Type.Value.I64)
|
||||||
|
protected fun MutableList<Node.Type.Value>?.pushF32() = push(Node.Type.Value.F32)
|
||||||
|
protected fun MutableList<Node.Type.Value>?.pushF64() = push(Node.Type.Value.F64)
|
||||||
|
|
||||||
|
data class InsnApply(
|
||||||
|
val insn: Node.Instr,
|
||||||
|
val stackAtBeginning: List<Node.Type.Value>?,
|
||||||
|
val stackChanges: List<StackChange>,
|
||||||
|
val unreachableUntilEndCount: Int
|
||||||
|
)
|
||||||
|
|
||||||
|
protected interface InsnApplyResponse
|
||||||
|
|
||||||
|
data class StackChange(
|
||||||
|
val type: Node.Type.Value?,
|
||||||
|
val pop: Boolean
|
||||||
|
) : InsnApplyResponse
|
||||||
|
|
||||||
|
data class Unreachable(
|
||||||
|
val untilEndCount: Int
|
||||||
|
) : InsnApplyResponse
|
||||||
|
|
||||||
|
class CachedModule(val mod: Node.Module) {
|
||||||
|
val importFuncs by lazy { mod.imports.mapNotNull { it.kind as? Node.Import.Kind.Func } }
|
||||||
|
val importGlobals by lazy { mod.imports.mapNotNull { it.kind as? Node.Import.Kind.Global } }
|
||||||
|
}
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
fun walkStrict(mod: Node.Module, func: Node.Func, afterInsn: ((Stack, Node.Instr) -> Unit)? = null) =
|
||||||
|
func.instructions.fold(Stack(
|
||||||
|
mod = CachedModule(mod),
|
||||||
|
func = func,
|
||||||
|
current = emptyList(),
|
||||||
|
strict = true
|
||||||
|
)) { stack, insn -> stack.next(insn).also { afterInsn?.invoke(it, insn) } }.also { stack ->
|
||||||
|
// We expect to be in an unreachable state at the end or have the single return value on the stack
|
||||||
|
if (stack.unreachableUntilNextEndCount == 0) {
|
||||||
|
val expectedStack = (func.type.ret?.let { listOf(it) } ?: emptyList())
|
||||||
|
require(expectedStack == stack.current) {
|
||||||
|
"Expected end to be $expectedStack, got ${stack.current}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fun stackChanges(v: Node.Instr, callFuncType: Node.Type.Func? = null) =
|
||||||
|
Stack().next(v, callFuncType).insnApplies.last().stackChanges
|
||||||
|
fun stackChanges(mod: CachedModule, func: Node.Func, v: Node.Instr) =
|
||||||
|
Stack(mod, func).next(v).insnApplies.last().stackChanges
|
||||||
|
fun stackDiff(v: Node.Instr, callFuncType: Node.Type.Func? = null) =
|
||||||
|
stackChanges(v, callFuncType).sumBy { if (it.pop) -1 else 1 }
|
||||||
|
fun stackDiff(mod: CachedModule, func: Node.Func, v: Node.Instr) =
|
||||||
|
stackChanges(mod, func, v).sumBy { if (it.pop) -1 else 1 }
|
||||||
|
}
|
||||||
|
}
|
183
compiler/src/main/kotlin/asmble/ast/opt/SplitLargeFunc.kt
Normal file
183
compiler/src/main/kotlin/asmble/ast/opt/SplitLargeFunc.kt
Normal file
@ -0,0 +1,183 @@
|
|||||||
|
package asmble.ast.opt
|
||||||
|
|
||||||
|
import asmble.ast.Node
|
||||||
|
import asmble.ast.Stack
|
||||||
|
|
||||||
|
// This is a naive implementation that just grabs adjacent sets of restricted insns and breaks the one that will save
|
||||||
|
// the most instructions off into its own function.
|
||||||
|
open class SplitLargeFunc(
|
||||||
|
val minSetLength: Int = 5,
|
||||||
|
val maxSetLength: Int = 40,
|
||||||
|
val maxParamCount: Int = 30
|
||||||
|
) {
|
||||||
|
|
||||||
|
// Null if no replacement. Second value is number of instructions saved. fnIndex must map to actual func,
|
||||||
|
// not imported one.
|
||||||
|
fun apply(mod: Node.Module, fnIndex: Int): Pair<Node.Module, Int>? {
|
||||||
|
// Get the func
|
||||||
|
val importFuncCount = mod.imports.count { it.kind is Node.Import.Kind.Func }
|
||||||
|
val actualFnIndex = fnIndex - importFuncCount
|
||||||
|
val func = mod.funcs.getOrElse(actualFnIndex) {
|
||||||
|
error("Unable to find non-import func at $fnIndex (actual $actualFnIndex)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Just take the best pattern and apply it
|
||||||
|
val newFuncIndex = importFuncCount + mod.funcs.size
|
||||||
|
return commonPatterns(mod, func).firstOrNull()?.let { pattern ->
|
||||||
|
// Name it as <funcname>$splitN (n is num just to disambiguate) if names are part of the mod
|
||||||
|
val newName = mod.names?.funcNames?.get(fnIndex)?.let {
|
||||||
|
"$it\$split".let { it + mod.names.funcNames.count { (_, v) -> v.startsWith(it) } }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Go over every replacement in reverse, changing the instructions to our new set
|
||||||
|
val newInsns = pattern.replacements.foldRight(func.instructions) { repl, insns ->
|
||||||
|
insns.take(repl.range.start) +
|
||||||
|
repl.preCallConsts +
|
||||||
|
Node.Instr.Call(newFuncIndex) +
|
||||||
|
insns.drop(repl.range.endInclusive + 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the module w/ the new function, it's new name, and the insns saved
|
||||||
|
mod.copy(
|
||||||
|
funcs = mod.funcs.toMutableList().also {
|
||||||
|
it[actualFnIndex] = func.copy(instructions = newInsns)
|
||||||
|
} + pattern.newFunc,
|
||||||
|
names = mod.names?.copy(funcNames = mod.names.funcNames.toMutableMap().also {
|
||||||
|
it[newFuncIndex] = newName!!
|
||||||
|
})
|
||||||
|
) to pattern.insnsSaved
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Results are by most insns saved. There can be overlap across patterns but never within a single pattern.
|
||||||
|
fun commonPatterns(mod: Node.Module, fn: Node.Func): List<CommonPattern> {
|
||||||
|
// Walk the stack for validation needs
|
||||||
|
val stack = Stack.walkStrict(mod, fn)
|
||||||
|
|
||||||
|
// Let's grab sets of insns that qualify. In this naive impl, in order to qualify the insn set needs to
|
||||||
|
// only have a certain set of insns that can be broken off. It can also only change the stack by 0 or 1
|
||||||
|
// value while never dipping below the starting stack. We also store the index they started at.
|
||||||
|
var insnSets = emptyList<InsnSet>()
|
||||||
|
// Pair in fold keyed by insn index
|
||||||
|
fn.instructions.foldIndexed(null as List<Pair<Int, Node.Instr>>?) { index, lastInsns, insn ->
|
||||||
|
if (!insn.canBeMoved) null else (lastInsns ?: emptyList()).plus(index to insn).also { fullNewInsnSet ->
|
||||||
|
// Get all final instructions between min and max size and with allowed param count (i.e. const count)
|
||||||
|
val trailingInsnSet = fullNewInsnSet.takeLast(maxSetLength)
|
||||||
|
|
||||||
|
// Get all instructions between the min and max
|
||||||
|
insnSets += (minSetLength..maxSetLength).
|
||||||
|
asSequence().
|
||||||
|
flatMap { trailingInsnSet.asSequence().windowed(it) }.
|
||||||
|
filter { it.count { it.second is Node.Instr.Args.Const<*> } <= maxParamCount }.
|
||||||
|
mapNotNull { newIndexedInsnSet ->
|
||||||
|
// Before adding, make sure it qualifies with the stack
|
||||||
|
InsnSet(
|
||||||
|
startIndex = newIndexedInsnSet.first().first,
|
||||||
|
insns = newIndexedInsnSet.map { it.second },
|
||||||
|
valueAddedToStack = null
|
||||||
|
).withStackValueIfValid(stack)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Sort the insn sets by the ones with the most insns
|
||||||
|
insnSets = insnSets.sortedByDescending { it.insns.size }
|
||||||
|
|
||||||
|
// Now let's create replacements for each, keyed by the extracted func
|
||||||
|
val patterns = insnSets.fold(emptyMap<Node.Func, List<Replacement>>()) { map, insnSet ->
|
||||||
|
insnSet.extractCommonFunc().let { (func, replacement) ->
|
||||||
|
val existingReplacements = map.getOrDefault(func, emptyList())
|
||||||
|
// Ignore if there is any overlap
|
||||||
|
if (existingReplacements.any(replacement::overlaps)) map
|
||||||
|
else map + (func to existingReplacements.plus(replacement))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now sort the patterns by most insns saved and return
|
||||||
|
return patterns.map { (k, v) ->
|
||||||
|
CommonPattern(k, v.sortedBy { it.range.first })
|
||||||
|
}.sortedByDescending { it.insnsSaved }
|
||||||
|
}
|
||||||
|
|
||||||
|
val Node.Instr.canBeMoved get() =
|
||||||
|
// No blocks
|
||||||
|
this !is Node.Instr.Block && this !is Node.Instr.Loop && this !is Node.Instr.If &&
|
||||||
|
this !is Node.Instr.Else && this !is Node.Instr.End &&
|
||||||
|
// No breaks
|
||||||
|
this !is Node.Instr.Br && this !is Node.Instr.BrIf && this !is Node.Instr.BrTable &&
|
||||||
|
// No return
|
||||||
|
this !is Node.Instr.Return &&
|
||||||
|
// No local access
|
||||||
|
this !is Node.Instr.GetLocal && this !is Node.Instr.SetLocal && this !is Node.Instr.TeeLocal
|
||||||
|
|
||||||
|
fun InsnSet.withStackValueIfValid(stack: Stack): InsnSet? {
|
||||||
|
// This makes sure that the stack only changes by at most one item and never dips below its starting val.
|
||||||
|
// If it is invalid, null is returned. If it qualifies and does change 1 value, it is set.
|
||||||
|
|
||||||
|
// First, make sure the stack after the last insn is the same as the first or the same + 1 val
|
||||||
|
val startingStack = stack.insnApplies[startIndex].stackAtBeginning!!
|
||||||
|
val endingStack = stack.insnApplies.getOrNull(startIndex + insns.size)?.stackAtBeginning ?: stack.current!!
|
||||||
|
if (endingStack.size != startingStack.size && endingStack.size != startingStack.size + 1) return null
|
||||||
|
if (endingStack.take(startingStack.size) != startingStack) return null
|
||||||
|
|
||||||
|
// Now, walk the insns and make sure they never pop below the start
|
||||||
|
var stackCounter = 0
|
||||||
|
stack.insnApplies.subList(startIndex, startIndex + insns.size).forEach {
|
||||||
|
it.stackChanges.forEach {
|
||||||
|
stackCounter += if (it.pop) -1 else 1
|
||||||
|
if (stackCounter < 0) return null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// We're good, now only if the ending stack is one over the start do we have a ret val
|
||||||
|
return copy(
|
||||||
|
valueAddedToStack = endingStack.lastOrNull()?.takeIf { endingStack.size == startingStack.size + 1 }
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun InsnSet.extractCommonFunc() =
|
||||||
|
// This extracts a function with constants changed to parameters
|
||||||
|
insns.fold(Pair(
|
||||||
|
Node.Func(Node.Type.Func(params = emptyList(), ret = valueAddedToStack), emptyList(), emptyList()),
|
||||||
|
Replacement(range = startIndex until startIndex + insns.size, preCallConsts = emptyList()))
|
||||||
|
) { (func, repl), insn ->
|
||||||
|
if (insn !is Node.Instr.Args.Const<*>) func.copy(instructions = func.instructions + insn) to repl
|
||||||
|
else func.copy(
|
||||||
|
type = func.type.copy(params = func.type.params + insn.constType),
|
||||||
|
instructions = func.instructions + Node.Instr.GetLocal(func.type.params.size)
|
||||||
|
) to repl.copy(preCallConsts = repl.preCallConsts + insn)
|
||||||
|
}
|
||||||
|
|
||||||
|
protected val Node.Instr.Args.Const<*>.constType get() = when (this) {
|
||||||
|
is Node.Instr.I32Const -> Node.Type.Value.I32
|
||||||
|
is Node.Instr.I64Const -> Node.Type.Value.I64
|
||||||
|
is Node.Instr.F32Const -> Node.Type.Value.F32
|
||||||
|
is Node.Instr.F64Const -> Node.Type.Value.F64
|
||||||
|
else -> error("unreachable")
|
||||||
|
}
|
||||||
|
|
||||||
|
data class InsnSet(
|
||||||
|
val startIndex: Int,
|
||||||
|
val insns: List<Node.Instr>,
|
||||||
|
val valueAddedToStack: Node.Type.Value?
|
||||||
|
)
|
||||||
|
|
||||||
|
data class Replacement(
|
||||||
|
val range: IntRange,
|
||||||
|
val preCallConsts: List<Node.Instr>
|
||||||
|
) {
|
||||||
|
// Subtract one because there is a call after this
|
||||||
|
val insnsSaved get() = (range.last + 1) - range.first - 1 - preCallConsts.size
|
||||||
|
fun overlaps(o: Replacement) = range.contains(o.range.first) || range.contains(o.range.last) ||
|
||||||
|
o.range.contains(range.first) || o.range.contains(range.last)
|
||||||
|
}
|
||||||
|
|
||||||
|
data class CommonPattern(
|
||||||
|
val newFunc: Node.Func,
|
||||||
|
// In order by earliest replacement first
|
||||||
|
val replacements: List<Replacement>
|
||||||
|
) {
|
||||||
|
// Replacement pieces saved (with one added for the invocation) less new func instructions
|
||||||
|
val insnsSaved get() = replacements.sumBy { it.insnsSaved } - newFunc.instructions.size
|
||||||
|
}
|
||||||
|
|
||||||
|
companion object : SplitLargeFunc()
|
||||||
|
}
|
@ -3,7 +3,7 @@ package asmble.cli
|
|||||||
import asmble.util.Logger
|
import asmble.util.Logger
|
||||||
import kotlin.system.exitProcess
|
import kotlin.system.exitProcess
|
||||||
|
|
||||||
val commands = listOf(Compile, Help, Invoke, Link, Run, Translate)
|
val commands = listOf(Compile, Help, Invoke, Link, Run, SplitFunc, Translate)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Entry point of command line interface.
|
* Entry point of command line interface.
|
||||||
|
146
compiler/src/main/kotlin/asmble/cli/SplitFunc.kt
Normal file
146
compiler/src/main/kotlin/asmble/cli/SplitFunc.kt
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
package asmble.cli
|
||||||
|
|
||||||
|
import asmble.ast.Node
|
||||||
|
import asmble.ast.Script
|
||||||
|
import asmble.ast.opt.SplitLargeFunc
|
||||||
|
|
||||||
|
open class SplitFunc : Command<SplitFunc.Args>() {
|
||||||
|
override val name = "split-func"
|
||||||
|
override val desc = "Split a WebAssembly function into two"
|
||||||
|
|
||||||
|
override fun args(bld: Command.ArgsBuilder) = Args(
|
||||||
|
inFile = bld.arg(
|
||||||
|
name = "inFile",
|
||||||
|
desc = "The wast or wasm WebAssembly file name. Can be '--' to read from stdin."
|
||||||
|
),
|
||||||
|
funcName = bld.arg(
|
||||||
|
name = "funcName",
|
||||||
|
desc = "The name (or '#' + function space index) of the function to split"
|
||||||
|
),
|
||||||
|
inFormat = bld.arg(
|
||||||
|
name = "inFormat",
|
||||||
|
opt = "in",
|
||||||
|
desc = "Either 'wast' or 'wasm' to describe format.",
|
||||||
|
default = "<use file extension>",
|
||||||
|
lowPriority = true
|
||||||
|
),
|
||||||
|
outFile = bld.arg(
|
||||||
|
name = "outFile",
|
||||||
|
opt = "outFile",
|
||||||
|
desc = "The wast or wasm WebAssembly file name. Can be '--' to write to stdout.",
|
||||||
|
default = "<inFileSansExt.split.wasm or stdout>",
|
||||||
|
lowPriority = true
|
||||||
|
),
|
||||||
|
outFormat = bld.arg(
|
||||||
|
name = "outFormat",
|
||||||
|
opt = "out",
|
||||||
|
desc = "Either 'wast' or 'wasm' to describe format.",
|
||||||
|
default = "<use file extension or wast for stdout>",
|
||||||
|
lowPriority = true
|
||||||
|
),
|
||||||
|
compact = bld.flag(
|
||||||
|
opt = "compact",
|
||||||
|
desc = "If set for wast out format, will be compacted.",
|
||||||
|
lowPriority = true
|
||||||
|
),
|
||||||
|
minInsnSetLength = bld.arg(
|
||||||
|
name = "minInsnSetLength",
|
||||||
|
opt = "minLen",
|
||||||
|
desc = "The minimum number of instructions allowed for the split off function.",
|
||||||
|
default = "5",
|
||||||
|
lowPriority = true
|
||||||
|
).toInt(),
|
||||||
|
maxInsnSetLength = bld.arg(
|
||||||
|
name = "maxInsnSetLength",
|
||||||
|
opt = "maxLen",
|
||||||
|
desc = "The maximum number of instructions allowed for the split off function.",
|
||||||
|
default = "40",
|
||||||
|
lowPriority = true
|
||||||
|
).toInt(),
|
||||||
|
maxNewFuncParamCount = bld.arg(
|
||||||
|
name = "maxNewFuncParamCount",
|
||||||
|
opt = "maxParams",
|
||||||
|
desc = "The maximum number of params allowed for the split off function.",
|
||||||
|
default = "30",
|
||||||
|
lowPriority = true
|
||||||
|
).toInt(),
|
||||||
|
attempts = bld.arg(
|
||||||
|
name = "attempts",
|
||||||
|
opt = "attempts",
|
||||||
|
desc = "The number of attempts to perform.",
|
||||||
|
default = "1",
|
||||||
|
lowPriority = true
|
||||||
|
).toInt()
|
||||||
|
).also { bld.done() }
|
||||||
|
|
||||||
|
override fun run(args: Args) {
|
||||||
|
// Load the mod
|
||||||
|
val translate = Translate().also { it.logger = logger }
|
||||||
|
val inFormat =
|
||||||
|
if (args.inFormat != "<use file extension>") args.inFormat
|
||||||
|
else args.inFile.substringAfterLast('.', "<unknown>")
|
||||||
|
val script = translate.inToAst(args.inFile, inFormat)
|
||||||
|
var mod = (script.commands.firstOrNull() as? Script.Cmd.Module)?.module ?: error("Only a single module allowed")
|
||||||
|
|
||||||
|
// Do attempts
|
||||||
|
val splitter = SplitLargeFunc(
|
||||||
|
minSetLength = args.minInsnSetLength,
|
||||||
|
maxSetLength = args.maxInsnSetLength,
|
||||||
|
maxParamCount = args.maxNewFuncParamCount
|
||||||
|
)
|
||||||
|
for (attempt in 0 until args.attempts) {
|
||||||
|
// Find the function
|
||||||
|
var index = mod.names?.funcNames?.toList()?.find { it.second == args.funcName }?.first
|
||||||
|
if (index == null && args.funcName.startsWith('#')) index = args.funcName.drop(1).toInt()
|
||||||
|
val origFunc = index?.let {
|
||||||
|
mod.funcs.getOrNull(it - mod.imports.count { it.kind is Node.Import.Kind.Func })
|
||||||
|
} ?: error("Unable to find func")
|
||||||
|
|
||||||
|
// Split it
|
||||||
|
val results = splitter.apply(mod, index)
|
||||||
|
if (results == null) {
|
||||||
|
logger.warn { "No instructions after attempt $attempt" }
|
||||||
|
break
|
||||||
|
}
|
||||||
|
val (splitMod, insnsSaved) = results
|
||||||
|
val newFunc = splitMod.funcs[index - mod.imports.count { it.kind is Node.Import.Kind.Func }]
|
||||||
|
val splitFunc = splitMod.funcs.last()
|
||||||
|
logger.warn {
|
||||||
|
"Split complete, from func with ${origFunc.instructions.size} insns to a func " +
|
||||||
|
"with ${newFunc.instructions.size} insns + delegated func " +
|
||||||
|
"with ${splitFunc.instructions.size} insns and ${splitFunc.type.params.size} params, " +
|
||||||
|
"saved $insnsSaved insns"
|
||||||
|
}
|
||||||
|
mod = splitMod
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write it
|
||||||
|
val outFile = when {
|
||||||
|
args.outFile != "<inFileSansExt.split.wasm or stdout>" -> args.outFile
|
||||||
|
args.inFile == "--" -> "--"
|
||||||
|
else -> args.inFile.replaceAfterLast('.', "split." + args.inFile.substringAfterLast('.'))
|
||||||
|
}
|
||||||
|
val outFormat = when {
|
||||||
|
args.outFormat != "<use file extension or wast for stdout>" -> args.outFormat
|
||||||
|
outFile == "--" -> "wast"
|
||||||
|
else -> outFile.substringAfterLast('.', "<unknown>")
|
||||||
|
}
|
||||||
|
translate.astToOut(outFile, outFormat, args.compact,
|
||||||
|
Script(listOf(Script.Cmd.Module(mod, mod.names?.moduleName))))
|
||||||
|
}
|
||||||
|
|
||||||
|
data class Args(
|
||||||
|
val inFile: String,
|
||||||
|
val inFormat: String,
|
||||||
|
val funcName: String,
|
||||||
|
val outFile: String,
|
||||||
|
val outFormat: String,
|
||||||
|
val compact: Boolean,
|
||||||
|
val minInsnSetLength: Int,
|
||||||
|
val maxInsnSetLength: Int,
|
||||||
|
val maxNewFuncParamCount: Int,
|
||||||
|
val attempts: Int
|
||||||
|
)
|
||||||
|
|
||||||
|
companion object : SplitFunc()
|
||||||
|
}
|
@ -52,24 +52,7 @@ open class Translate : Command<Translate.Args>() {
|
|||||||
if (args.outFormat != "<use file extension or wast for stdout>") args.outFormat
|
if (args.outFormat != "<use file extension or wast for stdout>") args.outFormat
|
||||||
else if (args.outFile == "--") "wast"
|
else if (args.outFile == "--") "wast"
|
||||||
else args.outFile.substringAfterLast('.', "<unknown>")
|
else args.outFile.substringAfterLast('.', "<unknown>")
|
||||||
val outStream =
|
astToOut(args.outFile, outFormat, args.compact, script)
|
||||||
if (args.outFile == "--") System.out
|
|
||||||
else FileOutputStream(args.outFile)
|
|
||||||
outStream.use { outStream ->
|
|
||||||
when (outFormat) {
|
|
||||||
"wast" -> {
|
|
||||||
val sexprToStr = if (args.compact) SExprToStr.Compact else SExprToStr
|
|
||||||
val sexprs = AstToSExpr.fromScript(script)
|
|
||||||
outStream.write(sexprToStr.fromSExpr(*sexprs.toTypedArray()).toByteArray())
|
|
||||||
}
|
|
||||||
"wasm" -> {
|
|
||||||
val mod = (script.commands.firstOrNull() as? Script.Cmd.Module)?.module ?:
|
|
||||||
error("Output to WASM requires input be just a single module")
|
|
||||||
AstToBinary.fromModule(ByteWriter.OutputStream(outStream), mod)
|
|
||||||
}
|
|
||||||
else -> error("Unknown out format '$outFormat'")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fun inToAst(inFile: String, inFormat: String): Script {
|
fun inToAst(inFile: String, inFormat: String): Script {
|
||||||
@ -93,6 +76,27 @@ open class Translate : Command<Translate.Args>() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun astToOut(outFile: String, outFormat: String, compact: Boolean, script: Script) {
|
||||||
|
val outStream =
|
||||||
|
if (outFile == "--") System.out
|
||||||
|
else FileOutputStream(outFile)
|
||||||
|
outStream.use { outStream ->
|
||||||
|
when (outFormat) {
|
||||||
|
"wast" -> {
|
||||||
|
val sexprToStr = if (compact) SExprToStr.Compact else SExprToStr
|
||||||
|
val sexprs = AstToSExpr.fromScript(script)
|
||||||
|
outStream.write(sexprToStr.fromSExpr(*sexprs.toTypedArray()).toByteArray())
|
||||||
|
}
|
||||||
|
"wasm" -> {
|
||||||
|
val mod = (script.commands.firstOrNull() as? Script.Cmd.Module)?.module ?:
|
||||||
|
error("Output to WASM requires input be just a single module")
|
||||||
|
AstToBinary.fromModule(ByteWriter.OutputStream(outStream), mod)
|
||||||
|
}
|
||||||
|
else -> error("Unknown out format '$outFormat'")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
data class Args(
|
data class Args(
|
||||||
val inFile: String,
|
val inFile: String,
|
||||||
val inFormat: String,
|
val inFormat: String,
|
||||||
|
33
compiler/src/test/kotlin/asmble/ast/StackTest.kt
Normal file
33
compiler/src/test/kotlin/asmble/ast/StackTest.kt
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
package asmble.ast
|
||||||
|
|
||||||
|
import asmble.SpecTestUnit
|
||||||
|
import asmble.TestBase
|
||||||
|
import org.junit.Test
|
||||||
|
import org.junit.runner.RunWith
|
||||||
|
import org.junit.runners.Parameterized
|
||||||
|
|
||||||
|
@RunWith(Parameterized::class)
|
||||||
|
class StackTest(val unit: SpecTestUnit) : TestBase() {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testStack() {
|
||||||
|
// If it's not a module expecting an error, we'll try to walk the stack on each function
|
||||||
|
unit.script.commands.mapNotNull { it as? Script.Cmd.Module }.forEach { mod ->
|
||||||
|
mod.module.funcs.filter { it.instructions.isNotEmpty() }.forEach { func ->
|
||||||
|
debug { "Func: ${func.type}" }
|
||||||
|
var indexCount = 0
|
||||||
|
Stack.walkStrict(mod.module, func) { stack, insn ->
|
||||||
|
debug { " After $insn (next: ${func.instructions.getOrNull(++indexCount)}, " +
|
||||||
|
"unreach depth: ${stack.unreachableUntilNextEndCount})" }
|
||||||
|
debug { " " + stack.current }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
// Only tests that shouldn't fail
|
||||||
|
@JvmStatic @Parameterized.Parameters(name = "{0}")
|
||||||
|
fun data() = SpecTestUnit.allUnits.filterNot { it.shouldFail }//.filter { it.name == "loop" }
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,89 @@
|
|||||||
|
package asmble.ast.opt
|
||||||
|
|
||||||
|
import asmble.TestBase
|
||||||
|
import asmble.ast.Node
|
||||||
|
import asmble.compile.jvm.AstToAsm
|
||||||
|
import asmble.compile.jvm.ClsContext
|
||||||
|
import asmble.run.jvm.ScriptContext
|
||||||
|
import org.junit.Test
|
||||||
|
import java.nio.ByteBuffer
|
||||||
|
import java.util.*
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
class SplitLargeFuncTest : TestBase() {
|
||||||
|
@Test
|
||||||
|
fun testSplitLargeFunc() {
|
||||||
|
// We're going to make a large function that does some addition and then stores in mem
|
||||||
|
val ctx = ClsContext(
|
||||||
|
packageName = "test",
|
||||||
|
className = "Temp" + UUID.randomUUID().toString().replace("-", ""),
|
||||||
|
logger = logger,
|
||||||
|
mod = Node.Module(
|
||||||
|
memories = listOf(Node.Type.Memory(Node.ResizableLimits(initial = 2, maximum = 2))),
|
||||||
|
funcs = listOf(Node.Func(
|
||||||
|
type = Node.Type.Func(params = emptyList(), ret = null),
|
||||||
|
locals = emptyList(),
|
||||||
|
instructions = (0 until 501).flatMap {
|
||||||
|
listOf<Node.Instr>(
|
||||||
|
Node.Instr.I32Const(it * 4),
|
||||||
|
// Let's to i * (i = 1)
|
||||||
|
Node.Instr.I32Const(it),
|
||||||
|
Node.Instr.I32Const(it - 1),
|
||||||
|
Node.Instr.I32Mul,
|
||||||
|
Node.Instr.I32Store(0, 0)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)),
|
||||||
|
names = Node.NameSection(
|
||||||
|
moduleName = null,
|
||||||
|
funcNames = mapOf(0 to "someFunc"),
|
||||||
|
localNames = emptyMap()
|
||||||
|
),
|
||||||
|
exports = listOf(
|
||||||
|
Node.Export("memory", Node.ExternalKind.MEMORY, 0),
|
||||||
|
Node.Export("someFunc", Node.ExternalKind.FUNCTION, 0)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
// Compile it
|
||||||
|
AstToAsm.fromModule(ctx)
|
||||||
|
val cls = ScriptContext.SimpleClassLoader(javaClass.classLoader, logger).fromBuiltContext(ctx)
|
||||||
|
val inst = cls.newInstance()
|
||||||
|
// Run someFunc
|
||||||
|
cls.getMethod("someFunc").invoke(inst)
|
||||||
|
// Get the memory out
|
||||||
|
val mem = cls.getMethod("getMemory").invoke(inst) as ByteBuffer
|
||||||
|
// Read out the mem values
|
||||||
|
(0 until 501).forEach { assertEquals(it * (it - 1), mem.getInt(it * 4)) }
|
||||||
|
|
||||||
|
// Now split it
|
||||||
|
val (splitMod, insnsSaved) = SplitLargeFunc.apply(ctx.mod, 0) ?: error("Nothing could be split")
|
||||||
|
// Count insns and confirm it is as expected
|
||||||
|
val origInsnCount = ctx.mod.funcs.sumBy { it.instructions.size }
|
||||||
|
val newInsnCount = splitMod.funcs.sumBy { it.instructions.size }
|
||||||
|
assertEquals(origInsnCount - newInsnCount, insnsSaved)
|
||||||
|
// Compile it
|
||||||
|
val splitCtx = ClsContext(
|
||||||
|
packageName = "test",
|
||||||
|
className = "Temp" + UUID.randomUUID().toString().replace("-", ""),
|
||||||
|
logger = logger,
|
||||||
|
mod = splitMod
|
||||||
|
)
|
||||||
|
AstToAsm.fromModule(splitCtx)
|
||||||
|
val splitCls = ScriptContext.SimpleClassLoader(javaClass.classLoader, logger).fromBuiltContext(splitCtx)
|
||||||
|
val splitInst = splitCls.newInstance()
|
||||||
|
// Run someFunc
|
||||||
|
splitCls.getMethod("someFunc").invoke(splitInst)
|
||||||
|
// Get the memory out and compare it
|
||||||
|
val splitMem = splitCls.getMethod("getMemory").invoke(splitInst) as ByteBuffer
|
||||||
|
assertEquals(mem, splitMem)
|
||||||
|
// Dump some info
|
||||||
|
logger.debug {
|
||||||
|
val orig = ctx.mod.funcs.first()
|
||||||
|
val (new, split) = splitMod.funcs.let { it.first() to it.last() }
|
||||||
|
"Split complete, from single func with ${orig.instructions.size} insns to func " +
|
||||||
|
"with ${new.instructions.size} insns + delegated func " +
|
||||||
|
"with ${split.instructions.size} insns and ${split.type.params.size} params"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -1,6 +1,7 @@
|
|||||||
package asmble.io
|
package asmble.io
|
||||||
|
|
||||||
import asmble.SpecTestUnit
|
import asmble.SpecTestUnit
|
||||||
|
import asmble.TestBase
|
||||||
import asmble.ast.Node
|
import asmble.ast.Node
|
||||||
import asmble.ast.Script
|
import asmble.ast.Script
|
||||||
import asmble.util.Logger
|
import asmble.util.Logger
|
||||||
@ -13,12 +14,10 @@ import java.io.ByteArrayOutputStream
|
|||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
@RunWith(Parameterized::class)
|
@RunWith(Parameterized::class)
|
||||||
class IoTest(val unit: SpecTestUnit) : Logger by Logger.Print(Logger.Level.INFO) {
|
class IoTest(val unit: SpecTestUnit) : TestBase() {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testIo() {
|
fun testIo() {
|
||||||
// Ignore things that are supposed to fail
|
|
||||||
if (unit.shouldFail) return
|
|
||||||
// Go from the AST to binary then back to AST then back to binary and confirm values are as expected
|
// Go from the AST to binary then back to AST then back to binary and confirm values are as expected
|
||||||
val ast1 = unit.script.commands.mapNotNull { (it as? Script.Cmd.Module)?.module?.also {
|
val ast1 = unit.script.commands.mapNotNull { (it as? Script.Cmd.Module)?.module?.also {
|
||||||
trace { "AST from script:\n" + SExprToStr.fromSExpr(AstToSExpr.fromModule(it)) }
|
trace { "AST from script:\n" + SExprToStr.fromSExpr(AstToSExpr.fromModule(it)) }
|
||||||
@ -46,7 +45,8 @@ class IoTest(val unit: SpecTestUnit) : Logger by Logger.Print(Logger.Level.INFO)
|
|||||||
}
|
}
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
|
// Only tests that shouldn't fail
|
||||||
@JvmStatic @Parameterized.Parameters(name = "{0}")
|
@JvmStatic @Parameterized.Parameters(name = "{0}")
|
||||||
fun data() = SpecTestUnit.allUnits
|
fun data() = SpecTestUnit.allUnits.filterNot { it.shouldFail }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
10
examples/go-simple/simple.go
Normal file
10
examples/go-simple/simple.go
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
fmt.Printf("Args: %v", os.Args)
|
||||||
|
}
|
@ -1,6 +1,8 @@
|
|||||||
rootProject.name = 'asmble'
|
rootProject.name = 'asmble'
|
||||||
include 'annotations',
|
include 'annotations',
|
||||||
'compiler',
|
'compiler',
|
||||||
// 'examples:rust-regex', // todo will be enabled when the problem with string max size will be solved
|
'examples:c-simple',
|
||||||
|
'examples:go-simple',
|
||||||
|
'examples:rust-regex',
|
||||||
'examples:rust-simple',
|
'examples:rust-simple',
|
||||||
'examples:rust-string'
|
'examples:rust-string'
|
Reference in New Issue
Block a user