76 Commits

Author SHA1 Message Date
58cf836b76 Disable Go and Rust examples and up project version 2018-10-01 13:40:38 +04:00
56c2c8d672 Fix lateinit error with logger for wasm files 2018-10-01 13:39:08 +04:00
da70c9fca4 Fix for "lateinit property logger has not been initialized" 2018-08-28 16:50:37 +04:00
a1a5563367 Publishing the fork to binTray 2018-08-14 17:32:10 +04:00
e489e7c889 Publishing the fork to binTray 2018-08-14 17:22:30 +04:00
c1391b2701 Expected but isn't working version 2018-08-14 12:11:48 +04:00
6b28c5a93b merge original-master to master 2018-08-10 13:10:28 +04:00
9dc4112512 Add beginning of CLI for func split issue #19 2018-07-29 01:18:39 -05:00
01d89947e5 Complete main func splitting impl with test 2018-07-28 23:28:50 -05:00
4373676448 Begin test for large split func on issue #19 2018-07-27 17:03:55 -05:00
94953a4ada More work on large func splitter for issue #19 2018-07-27 15:16:37 -05:00
472579020a Early code for WASM function splitting for issue #19 2018-07-27 02:56:30 -05:00
75de1d76e3 Finish stack walker 2018-07-27 01:15:48 -05:00
67e914d683 Begin abstracting stack walking for issue #19 2018-07-26 17:17:43 -05:00
4bc12f4b94 Begin work on Go examples for #14 2018-07-26 17:16:11 -05:00
1990f46743 merge asmble-master to master 2018-07-26 10:49:28 +04:00
559df45f09 Merge pull request #1 from cretz/master
Fetch master changes
2018-07-26 10:34:42 +04:00
73862e9bc9 Chunk data sections even smaller per #18 and update README explanation 2018-07-26 00:27:17 -05:00
1127b61eb5 Update README explanation about string const max 2018-07-26 00:08:40 -05:00
a66c05ad4a Support large data sections. Fixes #18 2018-07-26 00:05:18 -05:00
6786350f53 Fixed to set proper stack diff size for store insns 2018-07-26 00:03:49 -05:00
706da0d486 Removed prefixed dollar sign from sexpr names and add export names to dedupe check for issue #17 2018-07-25 16:30:48 -05:00
1d5c1e527a Emit given names in compiled class. Fixes #17 2018-07-25 15:59:27 -05:00
1430bf48a6 Support names in converters for issue #17 2018-07-25 15:19:25 -05:00
96febbecd5 Beginning of name support for issue #17 2018-07-25 12:57:54 -05:00
80a8a1fbb9 Temporary disable rust-regex example 2018-07-25 10:35:37 +04:00
dd72c7124c Fix Rust Simple and Rust String examples 2018-07-25 10:21:38 +04:00
c04a3c4a9b Add some java docs 2018-07-24 11:00:41 +04:00
51520ac07d Add some java docs 2018-07-23 12:52:30 +04:00
3c25b40c40 Maven fetch instructions 2018-07-20 16:15:26 -05:00
96458bdec7 Maven publishing support. Fixes #15 2018-07-20 15:59:03 -05:00
dd33676e50 README update for mutable globals issue #16 2018-07-20 15:05:57 -05:00
e9364574a3 Changed resizing to memory_grow for default max mem pages due to https://github.com/WebAssembly/spec/pull/808 2018-07-20 14:59:04 -05:00
73e6b5769a Support mutable globals. Fixes #16 2018-07-20 14:43:54 -05:00
9d87ce440f Add "too many locals" error 2018-07-20 10:20:03 -05:00
51bc8008e1 Update to latest spec 2018-07-20 10:13:27 -05:00
198c521dd7 Update Kotlin and minor README tweak 2018-07-20 09:23:27 -05:00
97660de6ba Add some java docs 2018-07-20 12:52:21 +04:00
cee7a86773 Fix rust-regex example 2018-07-19 17:02:25 +04:00
cfa4a35af1 remove C example 2018-07-19 09:14:01 +04:00
f24342959d Update spec and two related changes (detailed)
* Added LEB128 validation (ref: https://github.com/WebAssembly/spec/pull/750)
* Rename memory instructions (ref: https://github.com/WebAssembly/spec/pull/720)
2018-05-07 15:39:31 -05:00
368ab300fa Update Kotlin 2018-05-07 11:43:34 -05:00
94ba46eca9 Set version 0.2.0 2018-03-02 17:58:40 -06:00
7229ee6eb6 Remove emscripten runtime originally for issue #7 2018-03-02 13:55:17 -06:00
2dddd90d2f Update WASM spec tests and make several fixes to conform 2018-03-02 04:27:34 -06:00
e79cc2e36b Update Kotlin 2018-03-01 21:18:32 -06:00
3cb439e887 First C example for issue #11 2017-12-15 14:43:21 -06:00
923946f66f Add rustup update to Rust simple example README 2017-12-06 10:05:11 -06:00
ffbf6a5288 Clarify Rust example build prereqs 2017-12-06 10:01:52 -06:00
bc980b1e81 Example README word fix 2017-12-06 03:38:59 -06:00
7c61edf257 Fix corpus size number in example README 2017-12-06 03:27:02 -06:00
eaf4137c67 Added Rust regex benchmark. Fixes issue #9. 2017-12-06 03:13:52 -06:00
1418ba86cb More work on Rust regex example.
Issue #9
2017-12-05 23:05:27 -06:00
4febf34e69 Initial skeleton for issue #9 2017-11-29 16:42:12 -06:00
8b51e14c33 Rust string example for issue #9 2017-11-29 13:03:36 -06:00
e51da3116e Minor example README update 2017-11-29 09:45:07 -06:00
ff7c88bf6c First Rust example for issue #9 2017-11-28 17:00:38 -06:00
0c4fb45d79 Update WASM spec tests and change implementation to conform 2017-10-12 14:52:31 -05:00
cb8470f54f Upgrade kotlin and minor warning suppression 2017-10-10 16:14:27 -05:00
3e912b2b15 Begin linker dev for issue #8 2017-06-08 11:55:03 -05:00
43333edfd0 Updated wasm spec and enforced UTF-8 validation on binary strings 2017-05-24 17:17:37 -05:00
e9cdfc3b0f A few more tests for emscripten support for #7 2017-04-26 21:45:20 -05:00
b4140c8189 Initial work to support emscripten runtime for issue #7 2017-04-26 15:35:34 -05:00
d94b5ce898 Moved compiler to subproject 2017-04-24 18:49:54 -05:00
5430e19a2b Updated WASM spec tests to latest 2017-04-24 17:49:22 -05:00
706c76a5cd Minor cleanup of fix #4 for stack injection 2017-04-23 18:04:12 -05:00
7cffd74670 Fix stack injector to not inject into inner blocks. Fixes #4. 2017-04-23 17:53:41 -05:00
2eb4506237 Fixed stack diff w/ sqrt. Fixes #5. 2017-04-23 16:17:11 -05:00
a73e719f24 Fix block str writes. Fixes #3 2017-04-23 15:55:43 -05:00
b367c59ba3 Minor README note about assembleDist 2017-04-23 03:36:12 -05:00
febb07b8b2 Stopped the extra pop on get/set global. Fixes #2. 2017-04-23 03:35:03 -05:00
e2996212e9 Bit more work supporting emscripten emulation 2017-04-21 15:00:27 -05:00
5890a1cd7c Beginning of emscripten emulation 2017-04-20 23:35:02 -05:00
da1d94dc9e Minor cleanup of TODOs and warnings 2017-04-19 19:37:07 -05:00
a9dc8ddd77 Fixes #1. Fix large mem data, copysign stack count, and reworked leftover mem instance support 2017-04-19 00:33:32 -05:00
132b50772d Add 0.1.0 info to README 2017-04-17 12:40:14 -05:00
117 changed files with 306508 additions and 788 deletions

26
.gitignore vendored
View File

@ -1,6 +1,32 @@
.classpath
.project
.settings
/gradlew
/gradlew.bat
/.gradle
/.idea
/asmble.iml
/build
/gradle
/compiler/bin
/compiler/build
/compiler/out
/annotations/bin
/annotations/build
/annotations/out
/examples/c-simple/bin
/examples/c-simple/build
/examples/go-simple/bin
/examples/go-simple/build
/examples/rust-simple/Cargo.lock
/examples/rust-simple/bin
/examples/rust-simple/build
/examples/rust-simple/target
/examples/rust-string/Cargo.lock
/examples/rust-string/bin
/examples/rust-string/build
/examples/rust-string/target
/examples/rust-regex/Cargo.lock
/examples/rust-regex/bin
/examples/rust-regex/build
/examples/rust-regex/target

4
.gitmodules vendored
View File

@ -1,3 +1,3 @@
[submodule "src/test/resources/spec"]
path = src/test/resources/spec
[submodule "compiler/src/test/resources/spec"]
path = compiler/src/test/resources/spec
url = https://github.com/WebAssembly/spec.git

View File

@ -1,6 +1,6 @@
MIT License
Copyright (c) 2017 Chad Retz
Copyright (c) 2018 Chad Retz
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

View File

@ -32,7 +32,7 @@ The result will be:
70 : i32
Which is how the test harness prints an integer.
Which is how the test harness prints an integer. See the [examples](examples) directory for more examples.
## CLI Usage
@ -183,8 +183,16 @@ JVM languages.
### Getting
The latest tag can be added to your build script via [JitPack](https://jitpack.io). For example,
[here](https://jitpack.io/#cretz/asmble/master-SNAPSHOT) are instructions for using the latest master.
The compiler and annotations are deployed to Maven Central. The compiler is written in Kotlin and can be added as a
Gradle dependency with:
compile 'com.github.cretz.asmble:asmble-compiler:0.3.0'
This is only needed to compile of course, the compiled code has no runtime requirement. The compiled code does include
some annotations (but in Java its ok to have annotations that are not found). If you do want to reflect the annotations,
the annotation library can be added as a Gradle dependency with:
compile 'com.github.cretz.asmble:asmble-annotations:0.3.0'
### Building and Testing
@ -197,6 +205,7 @@ The reason we use recursive is to clone the spec submodule we have embedded at `
`gradle wrapper`. Now the `gradlew` command is available.
To build, run `./gradlew build`. This will run all tests which includes the test suite from the WebAssembly spec.
Running `./gradlew assembleDist` builds the same zip and tar files uploaded to the releases area.
### Library Notes
@ -254,15 +263,16 @@ In the WebAssembly MVP a table is just a set of function pointers. This is store
#### Globals
Globals are stored as fields on the class. A non-import global is simply a field, but an import global is a
`MethodHandle` to the getter (and would be a `MethodHandle` to the setter if mutable globals were supported). Any values
for the globals are set in the constructor.
Globals are stored as fields on the class. A non-import global is simply a field that is final if not mutable. An import
global is a `MethodHandle` to the getter and a `MethodHandle` to the setter if mutable. Any values for the globals are
set in the constructor.
#### Imports
The constructor accepts all imports as params. Memory is imported via a `ByteBuffer` param, then function
imports as `MethodHandle` params, then global imports as `MethodHandle` params, then a `MethodHandle` array param for an
imported table. All of these values are set as fields in the constructor.
imports as `MethodHandle` params, then global imports as `MethodHandle` params (one for getter and another for setter if
mutable), then a `MethodHandle` array param for an imported table. All of these values are set as fields in the
constructor.
#### Exports
@ -329,6 +339,9 @@ simply do normal field access.
Memory operations are done via `ByteBuffer` methods on a little-endian buffer. All operations including unsigned
operations are tailored to use specific existing Java stdlib functions.
As a special optimization, we put the memory instance as a local var if it is accessed a lot in a function. This is
cheaper than constantly fetching the field.
#### Number Operations
Constants are simply `ldc` bytecode ops on the JVM. Comparisons are done via specific bytecodes sometimes combined with
@ -358,9 +371,12 @@ stack (e.g. some places where we do a swap).
Below are some performance and implementation quirks where there is a bit of an impedance mismatch between WebAssembly
and the JVM:
* WebAssembly has a nice data section for byte arrays whereas the JVM does not. Right now we build a byte array from
a bunch of consts at runtime which is multiple operations per byte. This can bloat the class file size, but is quite
fast compared to alternatives such as string constants.
* WebAssembly has a nice data section for byte arrays whereas the JVM does not. Right now we use a single-byte-char
string constant (i.e. ISO-8859 charset). This saves class file size, but this means we call `String::getBytes` on
init to load bytes from the string constant. Due to the JVM using an unsigned 16-bit int as the string constant
length, the maximum byte length is 65536. Since the string constants are stored as UTF-8 constants, they can be up to
four bytes a character. Therefore, we populate memory in data chunks no larger than 16300 (nice round number to make
sure that even in the worse case of 4 bytes per char in UTF-8 view, we're still under the max).
* The JVM makes no guarantees about trailing bits being preserved on NaN floating point representations like WebAssembly
does. This causes some mismatch on WebAssembly tests depending on how the JVM "feels" (I haven't dug into why some
bit patterns stay and some don't when NaNs are passed through methods).
@ -412,6 +428,8 @@ WASM compiled from Rust, C, Java, etc if e.g. they all have their own way of han
definition of an importable set of modules that does all of these things, even if it's in WebIDL. I dunno, maybe the
effort is already there, I haven't really looked.
There is https://github.com/konsoletyper/teavm
**So I can compile something in C via Emscripten and have it run on the JVM with this?**
Yes, but work is required. WebAssembly is lacking any kind of standard library. So Emscripten will either embed it or
@ -430,3 +448,4 @@ Not yet, once source maps get standardized I may revisit.
* Add "link" command that will build an entire JAR out of several WebAssembly files and glue code between them
* Annotations to make it clear what imports are expected
* Compile to JS and native with Kotlin
* Add javax.script (which can give things like a free repl w/ jrunscript)

View File

@ -0,0 +1,11 @@
package asmble.annotation;
import java.lang.annotation.*;
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ ElementType.TYPE, ElementType.METHOD })
public @interface WasmExport {
String value();
WasmExternalKind kind() default WasmExternalKind.FUNCTION;
}

View File

@ -0,0 +1,5 @@
package asmble.annotation;
public enum WasmExternalKind {
MEMORY, GLOBAL, FUNCTION, TABLE
}

View File

@ -0,0 +1,17 @@
package asmble.annotation;
import java.lang.annotation.*;
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.PARAMETER)
public @interface WasmImport {
String module();
String field();
// The JVM method descriptor of an export that will match this
String desc();
WasmExternalKind kind();
int resizableLimitInitial() default -1;
int resizableLimitMaximum() default -1;
boolean globalSetter() default false;
}

View File

@ -0,0 +1,11 @@
package asmble.annotation;
import java.lang.annotation.*;
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface WasmModule {
String name() default "";
String binary() default "";
}

View File

@ -1,32 +1,55 @@
group 'asmble'
version '0.1.0'
version '0.2.0'
buildscript {
ext.kotlin_version = '1.1.1'
ext.kotlin_version = '1.2.51'
ext.asm_version = '5.2'
repositories {
mavenCentral()
maven {
url "https://plugins.gradle.org/m2/"
}
}
dependencies {
classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version"
classpath 'me.champeau.gradle:jmh-gradle-plugin:0.4.5'
classpath 'com.jfrog.bintray.gradle:gradle-bintray-plugin:1.8.4'
}
}
apply plugin: 'java'
apply plugin: 'kotlin'
apply plugin: 'application'
allprojects {
apply plugin: 'java'
group 'com.github.cretz.asmble'
version '0.4.0-fl-fix'
mainClassName = "asmble.cli.MainKt"
repositories {
repositories {
mavenCentral()
}
}
distTar.archiveName = 'asmble.tar'
distZip.archiveName = 'asmble.zip'
project(':annotations') {
javadoc {
options.links 'https://docs.oracle.com/javase/8/docs/api/'
// TODO: change when https://github.com/gradle/gradle/issues/2354 is fixed
options.addStringOption 'Xdoclint:all', '-Xdoclint:-missing'
}
dependencies {
publishSettings(project, 'asmble-annotations', 'Asmble WASM Annotations')
}
project(':compiler') {
apply plugin: 'kotlin'
apply plugin: 'application'
applicationName = "asmble"
mainClassName = "asmble.cli.MainKt"
distTar.archiveName = 'asmble.tar'
distZip.archiveName = 'asmble.zip'
dependencies {
compile project(':annotations')
compile "org.jetbrains.kotlin:kotlin-stdlib:$kotlin_version"
compile "org.jetbrains.kotlin:kotlin-reflect:$kotlin_version"
compile "org.ow2.asm:asm-tree:$asm_version"
@ -34,4 +57,203 @@ dependencies {
testCompile 'junit:junit:4.12'
testCompile "org.jetbrains.kotlin:kotlin-test-junit:$kotlin_version"
testCompile "org.ow2.asm:asm-debug-all:$asm_version"
}
publishSettings(project, 'asmble-compiler', 'Asmble WASM Compiler')
}
project(':examples') {
subprojects {
dependencies {
compileOnly project(':compiler')
}
// Go example helpers
task goToWasm {
doFirst {
mkdir 'build'
exec {
def goFileName = fileTree(dir: '.', includes: ['*.go']).files.iterator().next()
environment 'GOOS': 'js', 'GOARCH': 'wasm'
commandLine 'go', 'build', '-o', 'build/lib.wasm', goFileName
}
}
}
task compileGoWasm(type: JavaExec) {
dependsOn goToWasm
classpath configurations.compileClasspath
main = 'asmble.cli.MainKt'
doFirst {
// args 'help', 'compile'
def outFile = 'build/wasm-classes/' + wasmCompiledClassName.replace('.', '/') + '.class'
file(outFile).parentFile.mkdirs()
args 'compile', 'build/lib.wasm', wasmCompiledClassName, '-out', outFile, '-log', 'debug'
}
}
// Rust example helpers
ext.rustBuildRelease = true
task rustToWasm(type: Exec) {
if (rustBuildRelease) {
commandLine 'cargo', 'build', '--release'
} else {
commandLine 'cargo', 'build'
}
}
ext.rustWasmFileName = { ->
def buildType = rustBuildRelease ? 'release' : 'debug'
def wasmFiles = fileTree(dir: "target/wasm32-unknown-unknown/$buildType", includes: ['*.wasm']).files
if (wasmFiles.size() != 1) throw new GradleException('Expected single WASM file, got ' + wasmFiles.size())
return wasmFiles.iterator().next()
}
task rustWasmFile() {
dependsOn rustToWasm
doFirst {
println 'File: ' + rustWasmFileName()
}
}
task showRustWast(type: JavaExec) {
dependsOn rustToWasm
classpath configurations.compileClasspath
main = 'asmble.cli.MainKt'
doFirst {
args 'translate', rustWasmFileName()
}
}
task compileRustWasm(type: JavaExec) {
dependsOn rustToWasm
classpath configurations.compileClasspath
main = 'asmble.cli.MainKt'
doFirst {
// args 'help', 'compile'
def outFile = 'build/wasm-classes/' + wasmCompiledClassName.replace('.', '/') + '.class'
file(outFile).parentFile.mkdirs()
args 'compile', rustWasmFileName(), wasmCompiledClassName, '-out', outFile
}
}
}
}
//project(':examples:go-simple') {
// apply plugin: 'application'
// ext.wasmCompiledClassName = 'asmble.generated.GoSimple'
// dependencies {
// compile files('build/wasm-classes')
// }
// compileJava {
// dependsOn compileGoWasm
// }
// mainClassName = 'asmble.examples.gosimple.Main'
//}
// todo temporary disable Rust regex, because some strings in wasm code exceed the size in 65353 bytes.
// project(':examples:rust-regex') {
// apply plugin: 'application'
// apply plugin: 'me.champeau.gradle.jmh'
// ext.wasmCompiledClassName = 'asmble.generated.RustRegex'
// dependencies {
// compile files('build/wasm-classes')
// testCompile 'junit:junit:4.12'
// }
// compileJava {
// dependsOn compileRustWasm
// }
// mainClassName = 'asmble.examples.rustregex.Main'
// test {
// testLogging.showStandardStreams = true
// testLogging.events 'PASSED', 'SKIPPED'
// }
// jmh {
// iterations = 5
// warmupIterations = 5
// fork = 3
// }
// }
project(':examples:rust-simple') {
apply plugin: 'application'
ext.wasmCompiledClassName = 'asmble.generated.RustSimple'
dependencies {
compile files('build/wasm-classes')
}
compileJava {
dependsOn compileRustWasm
}
mainClassName = 'asmble.examples.rustsimple.Main'
}
project(':examples:rust-string') {
apply plugin: 'application'
ext.wasmCompiledClassName = 'asmble.generated.RustString'
dependencies {
compile files('build/wasm-classes')
}
compileJava {
dependsOn compileRustWasm
}
mainClassName = 'asmble.examples.ruststring.Main'
}
def publishSettings(project, projectName, projectDescription) {
project.with {
apply plugin: 'com.jfrog.bintray'
apply plugin: 'maven-publish'
apply plugin: 'maven'
task sourcesJar(type: Jar) {
from sourceSets.main.allJava
classifier = 'sources'
}
publishing {
publications {
MyPublication(MavenPublication) {
from components.java
groupId group
artifactId projectName
artifact sourcesJar
version version
}
}
}
bintray {
user = bintrayUser
key = bintrayKey
publications = ['MyPublication']
//[Default: false] Whether to override version artifacts already published
override = false
//[Default: false] Whether version should be auto published after an upload
publish = true
pkg {
repo = 'releases'
name = projectName
userOrg = 'fluencelabs'
licenses = ['MIT']
vcsUrl = 'https://github.com/fluencelabs/asmble'
version {
name = project.version
desc = projectDescription
released = new Date()
vcsTag = project.version
}
}
}
}
}

View File

@ -3,7 +3,19 @@ package asmble.ast
import java.util.*
import kotlin.reflect.KClass
/**
* All WebAssembly AST nodes as static inner classes.
*/
sealed class Node {
/**
* Wasm module definition.
*
* The unit of WebAssembly code is the module. A module collects definitions
* for types, functions, tables, memories, and globals. In addition, it can
* declare imports and exports and provide initialization logic in the form
* of data and element segments or a start function.
*/
data class Module(
val types: List<Type.Func> = emptyList(),
val imports: List<Import> = emptyList(),
@ -15,6 +27,7 @@ sealed class Node {
val elems: List<Elem> = emptyList(),
val funcs: List<Func> = emptyList(),
val data: List<Data> = emptyList(),
val names: NameSection? = null,
val customSections: List<CustomSection> = emptyList()
) : Node()
@ -148,6 +161,12 @@ sealed class Node {
}
}
data class NameSection(
val moduleName: String?,
val funcNames: Map<Int, String>,
val localNames: Map<Int, Map<Int, String>>
) : Node()
sealed class Instr : Node() {
fun op() = InstrOp.classToOpMap[this::class] ?: throw Exception("No op found for ${this::class}")
@ -165,12 +184,15 @@ sealed class Node {
interface Const<out T : Number> : Args { val value: T }
}
// Control flow
// Control instructions [https://www.w3.org/TR/2018/WD-wasm-core-1-20180215/#control-instructions]
object Unreachable : Instr(), Args.None
object Nop : Instr(), Args.None
data class Block(override val type: Type.Value?) : Instr(), Args.Type
data class Loop(override val type: Type.Value?) : Instr(), Args.Type
data class If(override val type: Type.Value?) : Instr(), Args.Type
object Else : Instr(), Args.None
object End : Instr(), Args.None
data class Br(override val relativeDepth: Int) : Instr(), Args.RelativeDepth
@ -181,25 +203,27 @@ sealed class Node {
) : Instr(), Args.Table
object Return : Instr()
// Call operators
data class Call(override val index: Int) : Instr(), Args.Index
data class CallIndirect(
override val index: Int,
override val reserved: Boolean
) : Instr(), Args.ReservedIndex
// Parametric operators
// Parametric instructions [https://www.w3.org/TR/2018/WD-wasm-core-1-20180215/#parametric-instructions]
object Drop : Instr(), Args.None
object Select : Instr(), Args.None
// Variable access
// Variable instructions [https://www.w3.org/TR/2018/WD-wasm-core-1-20180215/#variable-instructions]
data class GetLocal(override val index: Int) : Instr(), Args.Index
data class SetLocal(override val index: Int) : Instr(), Args.Index
data class TeeLocal(override val index: Int) : Instr(), Args.Index
data class GetGlobal(override val index: Int) : Instr(), Args.Index
data class SetGlobal(override val index: Int) : Instr(), Args.Index
// Memory operators
// Memory instructions [https://www.w3.org/TR/2018/WD-wasm-core-1-20180215/#memory-instructions]
data class I32Load(override val align: Int, override val offset: Long) : Instr(), Args.AlignOffset
data class I64Load(override val align: Int, override val offset: Long) : Instr(), Args.AlignOffset
data class F32Load(override val align: Int, override val offset: Long) : Instr(), Args.AlignOffset
@ -223,10 +247,12 @@ sealed class Node {
data class I64Store8(override val align: Int, override val offset: Long) : Instr(), Args.AlignOffset
data class I64Store16(override val align: Int, override val offset: Long) : Instr(), Args.AlignOffset
data class I64Store32(override val align: Int, override val offset: Long) : Instr(), Args.AlignOffset
data class CurrentMemory(override val reserved: Boolean) : Instr(), Args.Reserved
data class GrowMemory(override val reserved: Boolean) : Instr(), Args.Reserved
data class MemorySize(override val reserved: Boolean) : Instr(), Args.Reserved
data class MemoryGrow(override val reserved: Boolean) : Instr(), Args.Reserved
// Constants
// Numeric instructions [https://www.w3.org/TR/2018/WD-wasm-core-1-20180215/#numeric-instructions]
// Constants operators
data class I32Const(override val value: Int) : Instr(), Args.Const<Int>
data class I64Const(override val value: Long) : Instr(), Args.Const<Long>
data class F32Const(override val value: Float) : Instr(), Args.Const<Float>
@ -433,14 +459,20 @@ sealed class Node {
companion object {
// TODO: why can't I set a val in init?
var strToOpMap = emptyMap<String, InstrOp<*>>(); private set
var classToOpMap = emptyMap<KClass<out Instr>, InstrOp<*>>(); private set
var strToOpcodeMap = emptyMap<String, Short>(); private set
var opcodeToStrMap = emptyMap<Short, String>(); private set
val strToOpMap: Map<String, InstrOp<*>>
val classToOpMap: Map<KClass<out Instr>, InstrOp<*>>
val strToOpcodeMap: Map<String, Short>
val opcodeToStrMap: Map<Short, String>
fun op(opcode: Short) = opcodeToStrMap[opcode]?.let(strToOpMap::get) ?: error("No opcode found: $opcode")
init {
// Local vars, set to vals later
var strToOpMap = emptyMap<String, InstrOp<*>>()
var classToOpMap = emptyMap<KClass<out Instr>, InstrOp<*>>()
var strToOpcodeMap = emptyMap<String, Short>()
var opcodeToStrMap = emptyMap<Short, String>()
// Can't use reification here because inline funcs not allowed in nested context :-(
fun <T> opMapEntry(
name: String,
@ -505,8 +537,8 @@ sealed class Node {
opMapEntry("i64.store8", 0x3c, ::MemOpAlignOffsetArg, Instr::I64Store8, Instr.I64Store8::class)
opMapEntry("i64.store16", 0x3d, ::MemOpAlignOffsetArg, Instr::I64Store16, Instr.I64Store16::class)
opMapEntry("i64.store32", 0x3e, ::MemOpAlignOffsetArg, Instr::I64Store32, Instr.I64Store32::class)
opMapEntry("current_memory", 0x3f, ::MemOpReservedArg, Instr::CurrentMemory, Instr.CurrentMemory::class)
opMapEntry("grow_memory", 0x40, ::MemOpReservedArg, Instr::GrowMemory, Instr.GrowMemory::class)
opMapEntry("memory.size", 0x3f, ::MemOpReservedArg, Instr::MemorySize, Instr.MemorySize::class)
opMapEntry("memory.grow", 0x40, ::MemOpReservedArg, Instr::MemoryGrow, Instr.MemoryGrow::class)
opMapEntry("i32.const", 0x41, ::ConstOpIntArg, Instr::I32Const, Instr.I32Const::class)
opMapEntry("i64.const", 0x42, ::ConstOpLongArg, Instr::I64Const, Instr.I64Const::class)
@ -639,6 +671,11 @@ sealed class Node {
opMapEntry("i64.reinterpret/f64", 0xbd, ReinterpretOp::NoArg, Instr.I64ReinterpretF64, Instr.I64ReinterpretF64::class)
opMapEntry("f32.reinterpret/i32", 0xbe, ReinterpretOp::NoArg, Instr.F32ReinterpretI32, Instr.F32ReinterpretI32::class)
opMapEntry("f64.reinterpret/i64", 0xbf, ReinterpretOp::NoArg, Instr.F64ReinterpretI64, Instr.F64ReinterpretI64::class)
this.strToOpMap = strToOpMap
this.classToOpMap = classToOpMap
this.strToOpcodeMap = strToOpcodeMap
this.opcodeToStrMap = opcodeToStrMap
}
}
}

View File

@ -2,13 +2,24 @@ package asmble.ast
import asmble.io.SExprToStr
/**
* Ast representation of wasm S-expressions (wast format).
* see [[https://webassembly.github.io/spec/core/text/index.html]]
*/
sealed class SExpr {
data class Multi(val vals: List<SExpr> = emptyList()) : SExpr() {
override fun toString() = SExprToStr.Compact.fromSExpr(this)
}
data class Symbol(val contents: String = "", val quoted: Boolean = false) : SExpr() {
data class Symbol(
val contents: String = "",
val quoted: Boolean = false,
val hasNonUtf8ByteSeqs: Boolean = false
) : SExpr() {
override fun toString() = SExprToStr.Compact.fromSExpr(this)
// This is basically the same as the deprecated java.lang.String#getBytes
fun rawContentCharsToBytes() = contents.toCharArray().map(Char::toByte)
}
}

View File

@ -1,6 +1,10 @@
package asmble.ast
/**
* Ast representation of wasm script.
*/
data class Script(val commands: List<Cmd>) {
sealed class Cmd {
data class Module(val module: Node.Module, val name: String?): Cmd()
data class Register(val string: String, val name: String?): Cmd()

View File

@ -0,0 +1,250 @@
package asmble.ast
// This is a utility for walking the stack. It can do validation or just walk naively.
data class Stack(
// If some of these values below are null, the pops/pushes may appear "unknown"
val mod: CachedModule? = null,
val func: Node.Func? = null,
// Null if not tracking the current stack and all pops succeed
val current: List<Node.Type.Value>? = null,
val insnApplies: List<InsnApply> = emptyList(),
val strict: Boolean = false,
val unreachableUntilNextEndCount: Int = 0
) {
fun next(v: Node.Instr, callFuncTypeOverride: Node.Type.Func? = null) = insnApply(v) {
// If we're unreachable, and not an end, we skip and move on
if (unreachableUntilNextEndCount > 0 && v !is Node.Instr.End) {
// If it's a block, we increase it because we'll see another end
return@insnApply if (v is Node.Instr.Args.Type) unreachable(unreachableUntilNextEndCount + 1) else nop()
}
when (v) {
is Node.Instr.Nop, is Node.Instr.Block, is Node.Instr.Loop -> nop()
is Node.Instr.If, is Node.Instr.BrIf -> popI32()
is Node.Instr.Return -> (func?.type?.ret?.let { pop(it) } ?: nop()) + unreachable(1)
is Node.Instr.Unreachable -> unreachable(1)
is Node.Instr.End, is Node.Instr.Else -> {
// Put back what was before the last block and add the block's type
// Go backwards to find the starting block
var currDepth = 0
val found = insnApplies.findLast {
when (it.insn) {
is Node.Instr.End -> { currDepth++; false }
is Node.Instr.Args.Type -> if (currDepth > 0) { currDepth--; false } else true
else -> false
}
}?.takeIf {
// When it's else, needs to be if
v !is Node.Instr.Else || it.insn is Node.Instr.If
}
val changes = when {
found != null && found.insn is Node.Instr.Args.Type &&
found.stackAtBeginning != null && this != null -> {
// Pop everything from before the block's start, then push if necessary...
// The If block includes an int at the beginning we must not include when subtracting
var preBlockStackSize = found.stackAtBeginning.size
if (found.insn is Node.Instr.If) preBlockStackSize--
val popped =
if (unreachableUntilNextEndCount > 1) nop()
else (0 until (size - preBlockStackSize)).flatMap { pop() }
// Only push if this is not an else
val pushed =
if (unreachableUntilNextEndCount > 1 || v is Node.Instr.Else) nop()
else (found.insn.type?.let { push(it) } ?: nop())
popped + pushed
}
strict -> error("Unable to find starting block for end")
else -> nop()
}
if (unreachableUntilNextEndCount > 0) changes + unreachable(unreachableUntilNextEndCount - 1)
else changes
}
is Node.Instr.Br -> unreachable(v.relativeDepth + 1)
is Node.Instr.BrTable -> popI32() + unreachable(1)
is Node.Instr.Call -> (callFuncTypeOverride ?: func(v.index)).let {
if (it == null) error("Call func type missing")
it.params.reversed().flatMap { pop(it) } + (it.ret?.let { push(it) } ?: nop())
}
is Node.Instr.CallIndirect -> (callFuncTypeOverride ?: mod?.mod?.types?.getOrNull(v.index)).let {
if (it == null) error("Call func type missing")
// We add one for the table index
popI32() + it.params.reversed().flatMap { pop(it) } + (it.ret?.let { push(it) } ?: nop())
}
is Node.Instr.Drop -> pop()
is Node.Instr.Select -> popI32() + pop().let { it + pop(it.first().type) + push(it.first().type) }
is Node.Instr.GetLocal -> push(local(v.index))
is Node.Instr.SetLocal -> pop(local(v.index))
is Node.Instr.TeeLocal -> local(v.index).let { pop(it) + push(it) }
is Node.Instr.GetGlobal -> push(global(v.index))
is Node.Instr.SetGlobal -> pop(global(v.index))
is Node.Instr.I32Load, is Node.Instr.I32Load8S, is Node.Instr.I32Load8U,
is Node.Instr.I32Load16U, is Node.Instr.I32Load16S -> popI32() + pushI32()
is Node.Instr.I64Load, is Node.Instr.I64Load8S, is Node.Instr.I64Load8U, is Node.Instr.I64Load16U,
is Node.Instr.I64Load16S, is Node.Instr.I64Load32S, is Node.Instr.I64Load32U -> popI32() + pushI64()
is Node.Instr.F32Load -> popI32() + pushF32()
is Node.Instr.F64Load -> popI32() + pushF64()
is Node.Instr.I32Store, is Node.Instr.I32Store8, is Node.Instr.I32Store16 -> popI32() + popI32()
is Node.Instr.I64Store, is Node.Instr.I64Store8,
is Node.Instr.I64Store16, is Node.Instr.I64Store32 -> popI64() + popI32()
is Node.Instr.F32Store -> popF32() + popI32()
is Node.Instr.F64Store -> popF64() + popI32()
is Node.Instr.MemorySize -> pushI32()
is Node.Instr.MemoryGrow -> popI32() + pushI32()
is Node.Instr.I32Const -> pushI32()
is Node.Instr.I64Const -> pushI64()
is Node.Instr.F32Const -> pushF32()
is Node.Instr.F64Const -> pushF64()
is Node.Instr.I32Add, is Node.Instr.I32Sub, is Node.Instr.I32Mul, is Node.Instr.I32DivS,
is Node.Instr.I32DivU, is Node.Instr.I32RemS, is Node.Instr.I32RemU, is Node.Instr.I32And,
is Node.Instr.I32Or, is Node.Instr.I32Xor, is Node.Instr.I32Shl, is Node.Instr.I32ShrS,
is Node.Instr.I32ShrU, is Node.Instr.I32Rotl, is Node.Instr.I32Rotr, is Node.Instr.I32Eq,
is Node.Instr.I32Ne, is Node.Instr.I32LtS, is Node.Instr.I32LeS, is Node.Instr.I32LtU,
is Node.Instr.I32LeU, is Node.Instr.I32GtS, is Node.Instr.I32GeS, is Node.Instr.I32GtU,
is Node.Instr.I32GeU -> popI32() + popI32() + pushI32()
is Node.Instr.I32Clz, is Node.Instr.I32Ctz, is Node.Instr.I32Popcnt,
is Node.Instr.I32Eqz -> popI32() + pushI32()
is Node.Instr.I64Add, is Node.Instr.I64Sub, is Node.Instr.I64Mul, is Node.Instr.I64DivS,
is Node.Instr.I64DivU, is Node.Instr.I64RemS, is Node.Instr.I64RemU, is Node.Instr.I64And,
is Node.Instr.I64Or, is Node.Instr.I64Xor, is Node.Instr.I64Shl, is Node.Instr.I64ShrS,
is Node.Instr.I64ShrU, is Node.Instr.I64Rotl, is Node.Instr.I64Rotr -> popI64() + popI64() + pushI64()
is Node.Instr.I64Eq, is Node.Instr.I64Ne, is Node.Instr.I64LtS, is Node.Instr.I64LeS,
is Node.Instr.I64LtU, is Node.Instr.I64LeU, is Node.Instr.I64GtS,
is Node.Instr.I64GeS, is Node.Instr.I64GtU, is Node.Instr.I64GeU -> popI64() + popI64() + pushI32()
is Node.Instr.I64Clz, is Node.Instr.I64Ctz, is Node.Instr.I64Popcnt -> popI64() + pushI64()
is Node.Instr.I64Eqz -> popI64() + pushI32()
is Node.Instr.F32Add, is Node.Instr.F32Sub, is Node.Instr.F32Mul, is Node.Instr.F32Div,
is Node.Instr.F32Min, is Node.Instr.F32Max, is Node.Instr.F32CopySign -> popF32() + popF32() + pushF32()
is Node.Instr.F32Eq, is Node.Instr.F32Ne, is Node.Instr.F32Lt, is Node.Instr.F32Le,
is Node.Instr.F32Gt, is Node.Instr.F32Ge -> popF32() + popF32() + pushI32()
is Node.Instr.F32Abs, is Node.Instr.F32Neg, is Node.Instr.F32Ceil, is Node.Instr.F32Floor,
is Node.Instr.F32Trunc, is Node.Instr.F32Nearest, is Node.Instr.F32Sqrt -> popF32() + pushF32()
is Node.Instr.F64Add, is Node.Instr.F64Sub, is Node.Instr.F64Mul, is Node.Instr.F64Div,
is Node.Instr.F64Min, is Node.Instr.F64Max, is Node.Instr.F64CopySign -> popF64() + popF64() + pushF64()
is Node.Instr.F64Eq, is Node.Instr.F64Ne, is Node.Instr.F64Lt, is Node.Instr.F64Le,
is Node.Instr.F64Gt, is Node.Instr.F64Ge -> popF64() + popF64() + pushI32()
is Node.Instr.F64Abs, is Node.Instr.F64Neg, is Node.Instr.F64Ceil, is Node.Instr.F64Floor,
is Node.Instr.F64Trunc, is Node.Instr.F64Nearest, is Node.Instr.F64Sqrt -> popF64() + pushF64()
is Node.Instr.I32WrapI64 -> popI64() + pushI32()
is Node.Instr.I32TruncSF32, is Node.Instr.I32TruncUF32,
is Node.Instr.I32ReinterpretF32 -> popF32() + pushI32()
is Node.Instr.I32TruncSF64, is Node.Instr.I32TruncUF64 -> popF64() + pushI32()
is Node.Instr.I64ExtendSI32, is Node.Instr.I64ExtendUI32 -> popI32() + pushI64()
is Node.Instr.I64TruncSF32, is Node.Instr.I64TruncUF32 -> popF32() + pushI64()
is Node.Instr.I64TruncSF64, is Node.Instr.I64TruncUF64,
is Node.Instr.I64ReinterpretF64 -> popF64() + pushI64()
is Node.Instr.F32ConvertSI32, is Node.Instr.F32ConvertUI32,
is Node.Instr.F32ReinterpretI32 -> popI32() + pushF32()
is Node.Instr.F32ConvertSI64, is Node.Instr.F32ConvertUI64 -> popI64() + pushF32()
is Node.Instr.F32DemoteF64 -> popF64() + pushF32()
is Node.Instr.F64ConvertSI32, is Node.Instr.F64ConvertUI32 -> popI32() + pushF64()
is Node.Instr.F64ConvertSI64, is Node.Instr.F64ConvertUI64,
is Node.Instr.F64ReinterpretI64 -> popI64() + pushF64()
is Node.Instr.F64PromoteF32 -> popF32() + pushF64()
}
}
protected fun insnApply(v: Node.Instr, fn: MutableList<Node.Type.Value>?.() -> List<InsnApplyResponse>): Stack {
val mutStack = current?.toMutableList()
val applyResp = mutStack.fn()
val newUnreachable = (applyResp.find { it is Unreachable } as? Unreachable)?.untilEndCount
return copy(
current = mutStack,
insnApplies = insnApplies + InsnApply(
insn = v,
stackAtBeginning = current,
stackChanges = applyResp.mapNotNull { it as? StackChange },
unreachableUntilEndCount = newUnreachable ?: unreachableUntilNextEndCount
),
unreachableUntilNextEndCount = newUnreachable ?: unreachableUntilNextEndCount
)
}
protected fun unreachable(untilEndCount: Int) = listOf(Unreachable(untilEndCount))
protected fun local(index: Int) = func?.let {
it.type.params.getOrNull(index) ?: it.locals.getOrNull(index - it.type.params.size)
}
protected fun global(index: Int) = mod?.let {
it.importGlobals.getOrNull(index)?.type?.contentType ?:
it.mod.globals.getOrNull(index - it.importGlobals.size)?.type?.contentType
}
protected fun func(index: Int) = mod?.let {
it.importFuncs.getOrNull(index)?.typeIndex?.let { i -> it.mod.types.getOrNull(i) } ?:
it.mod.funcs.getOrNull(index - it.importFuncs.size)?.type
}
protected fun nop() = emptyList<StackChange>()
protected fun MutableList<Node.Type.Value>?.popType(expecting: Node.Type.Value? = null) =
this?.takeIf {
it.isNotEmpty().also {
require(!strict || it) { "Expected $expecting got empty" }
}
}?.let {
removeAt(size - 1).takeIf { actual -> (expecting == null || actual == expecting).also {
require(!strict || it) { "Expected $expecting got $actual" }
} }
} ?: expecting
protected fun MutableList<Node.Type.Value>?.pop(expecting: Node.Type.Value? = null) =
listOf(StackChange(popType(expecting), true))
protected fun MutableList<Node.Type.Value>?.popI32() = pop(Node.Type.Value.I32)
protected fun MutableList<Node.Type.Value>?.popI64() = pop(Node.Type.Value.I64)
protected fun MutableList<Node.Type.Value>?.popF32() = pop(Node.Type.Value.F32)
protected fun MutableList<Node.Type.Value>?.popF64() = pop(Node.Type.Value.F64)
protected fun MutableList<Node.Type.Value>?.push(type: Node.Type.Value? = null) =
listOf(StackChange(type, false)).also { if (this != null && type != null) add(type) }
protected fun MutableList<Node.Type.Value>?.pushI32() = push(Node.Type.Value.I32)
protected fun MutableList<Node.Type.Value>?.pushI64() = push(Node.Type.Value.I64)
protected fun MutableList<Node.Type.Value>?.pushF32() = push(Node.Type.Value.F32)
protected fun MutableList<Node.Type.Value>?.pushF64() = push(Node.Type.Value.F64)
data class InsnApply(
val insn: Node.Instr,
val stackAtBeginning: List<Node.Type.Value>?,
val stackChanges: List<StackChange>,
val unreachableUntilEndCount: Int
)
protected interface InsnApplyResponse
data class StackChange(
val type: Node.Type.Value?,
val pop: Boolean
) : InsnApplyResponse
data class Unreachable(
val untilEndCount: Int
) : InsnApplyResponse
class CachedModule(val mod: Node.Module) {
val importFuncs by lazy { mod.imports.mapNotNull { it.kind as? Node.Import.Kind.Func } }
val importGlobals by lazy { mod.imports.mapNotNull { it.kind as? Node.Import.Kind.Global } }
}
companion object {
fun walkStrict(mod: Node.Module, func: Node.Func, afterInsn: ((Stack, Node.Instr) -> Unit)? = null) =
func.instructions.fold(Stack(
mod = CachedModule(mod),
func = func,
current = emptyList(),
strict = true
)) { stack, insn -> stack.next(insn).also { afterInsn?.invoke(it, insn) } }.also { stack ->
// We expect to be in an unreachable state at the end or have the single return value on the stack
if (stack.unreachableUntilNextEndCount == 0) {
val expectedStack = (func.type.ret?.let { listOf(it) } ?: emptyList())
require(expectedStack == stack.current) {
"Expected end to be $expectedStack, got ${stack.current}"
}
}
}
fun stackChanges(v: Node.Instr, callFuncType: Node.Type.Func? = null) =
Stack().next(v, callFuncType).insnApplies.last().stackChanges
fun stackChanges(mod: CachedModule, func: Node.Func, v: Node.Instr) =
Stack(mod, func).next(v).insnApplies.last().stackChanges
fun stackDiff(v: Node.Instr, callFuncType: Node.Type.Func? = null) =
stackChanges(v, callFuncType).sumBy { if (it.pop) -1 else 1 }
fun stackDiff(mod: CachedModule, func: Node.Func, v: Node.Instr) =
stackChanges(mod, func, v).sumBy { if (it.pop) -1 else 1 }
}
}

View File

@ -0,0 +1,183 @@
package asmble.ast.opt
import asmble.ast.Node
import asmble.ast.Stack
// This is a naive implementation that just grabs adjacent sets of restricted insns and breaks the one that will save
// the most instructions off into its own function.
open class SplitLargeFunc(
val minSetLength: Int = 5,
val maxSetLength: Int = 40,
val maxParamCount: Int = 30
) {
// Null if no replacement. Second value is number of instructions saved. fnIndex must map to actual func,
// not imported one.
fun apply(mod: Node.Module, fnIndex: Int): Pair<Node.Module, Int>? {
// Get the func
val importFuncCount = mod.imports.count { it.kind is Node.Import.Kind.Func }
val actualFnIndex = fnIndex - importFuncCount
val func = mod.funcs.getOrElse(actualFnIndex) {
error("Unable to find non-import func at $fnIndex (actual $actualFnIndex)")
}
// Just take the best pattern and apply it
val newFuncIndex = importFuncCount + mod.funcs.size
return commonPatterns(mod, func).firstOrNull()?.let { pattern ->
// Name it as <funcname>$splitN (n is num just to disambiguate) if names are part of the mod
val newName = mod.names?.funcNames?.get(fnIndex)?.let {
"$it\$split".let { it + mod.names.funcNames.count { (_, v) -> v.startsWith(it) } }
}
// Go over every replacement in reverse, changing the instructions to our new set
val newInsns = pattern.replacements.foldRight(func.instructions) { repl, insns ->
insns.take(repl.range.start) +
repl.preCallConsts +
Node.Instr.Call(newFuncIndex) +
insns.drop(repl.range.endInclusive + 1)
}
// Return the module w/ the new function, it's new name, and the insns saved
mod.copy(
funcs = mod.funcs.toMutableList().also {
it[actualFnIndex] = func.copy(instructions = newInsns)
} + pattern.newFunc,
names = mod.names?.copy(funcNames = mod.names.funcNames.toMutableMap().also {
it[newFuncIndex] = newName!!
})
) to pattern.insnsSaved
}
}
// Results are by most insns saved. There can be overlap across patterns but never within a single pattern.
fun commonPatterns(mod: Node.Module, fn: Node.Func): List<CommonPattern> {
// Walk the stack for validation needs
val stack = Stack.walkStrict(mod, fn)
// Let's grab sets of insns that qualify. In this naive impl, in order to qualify the insn set needs to
// only have a certain set of insns that can be broken off. It can also only change the stack by 0 or 1
// value while never dipping below the starting stack. We also store the index they started at.
var insnSets = emptyList<InsnSet>()
// Pair in fold keyed by insn index
fn.instructions.foldIndexed(null as List<Pair<Int, Node.Instr>>?) { index, lastInsns, insn ->
if (!insn.canBeMoved) null else (lastInsns ?: emptyList()).plus(index to insn).also { fullNewInsnSet ->
// Get all final instructions between min and max size and with allowed param count (i.e. const count)
val trailingInsnSet = fullNewInsnSet.takeLast(maxSetLength)
// Get all instructions between the min and max
insnSets += (minSetLength..maxSetLength).
asSequence().
flatMap { trailingInsnSet.asSequence().windowed(it) }.
filter { it.count { it.second is Node.Instr.Args.Const<*> } <= maxParamCount }.
mapNotNull { newIndexedInsnSet ->
// Before adding, make sure it qualifies with the stack
InsnSet(
startIndex = newIndexedInsnSet.first().first,
insns = newIndexedInsnSet.map { it.second },
valueAddedToStack = null
).withStackValueIfValid(stack)
}
}
}
// Sort the insn sets by the ones with the most insns
insnSets = insnSets.sortedByDescending { it.insns.size }
// Now let's create replacements for each, keyed by the extracted func
val patterns = insnSets.fold(emptyMap<Node.Func, List<Replacement>>()) { map, insnSet ->
insnSet.extractCommonFunc().let { (func, replacement) ->
val existingReplacements = map.getOrDefault(func, emptyList())
// Ignore if there is any overlap
if (existingReplacements.any(replacement::overlaps)) map
else map + (func to existingReplacements.plus(replacement))
}
}
// Now sort the patterns by most insns saved and return
return patterns.map { (k, v) ->
CommonPattern(k, v.sortedBy { it.range.first })
}.sortedByDescending { it.insnsSaved }
}
val Node.Instr.canBeMoved get() =
// No blocks
this !is Node.Instr.Block && this !is Node.Instr.Loop && this !is Node.Instr.If &&
this !is Node.Instr.Else && this !is Node.Instr.End &&
// No breaks
this !is Node.Instr.Br && this !is Node.Instr.BrIf && this !is Node.Instr.BrTable &&
// No return
this !is Node.Instr.Return &&
// No local access
this !is Node.Instr.GetLocal && this !is Node.Instr.SetLocal && this !is Node.Instr.TeeLocal
fun InsnSet.withStackValueIfValid(stack: Stack): InsnSet? {
// This makes sure that the stack only changes by at most one item and never dips below its starting val.
// If it is invalid, null is returned. If it qualifies and does change 1 value, it is set.
// First, make sure the stack after the last insn is the same as the first or the same + 1 val
val startingStack = stack.insnApplies[startIndex].stackAtBeginning!!
val endingStack = stack.insnApplies.getOrNull(startIndex + insns.size)?.stackAtBeginning ?: stack.current!!
if (endingStack.size != startingStack.size && endingStack.size != startingStack.size + 1) return null
if (endingStack.take(startingStack.size) != startingStack) return null
// Now, walk the insns and make sure they never pop below the start
var stackCounter = 0
stack.insnApplies.subList(startIndex, startIndex + insns.size).forEach {
it.stackChanges.forEach {
stackCounter += if (it.pop) -1 else 1
if (stackCounter < 0) return null
}
}
// We're good, now only if the ending stack is one over the start do we have a ret val
return copy(
valueAddedToStack = endingStack.lastOrNull()?.takeIf { endingStack.size == startingStack.size + 1 }
)
}
fun InsnSet.extractCommonFunc() =
// This extracts a function with constants changed to parameters
insns.fold(Pair(
Node.Func(Node.Type.Func(params = emptyList(), ret = valueAddedToStack), emptyList(), emptyList()),
Replacement(range = startIndex until startIndex + insns.size, preCallConsts = emptyList()))
) { (func, repl), insn ->
if (insn !is Node.Instr.Args.Const<*>) func.copy(instructions = func.instructions + insn) to repl
else func.copy(
type = func.type.copy(params = func.type.params + insn.constType),
instructions = func.instructions + Node.Instr.GetLocal(func.type.params.size)
) to repl.copy(preCallConsts = repl.preCallConsts + insn)
}
protected val Node.Instr.Args.Const<*>.constType get() = when (this) {
is Node.Instr.I32Const -> Node.Type.Value.I32
is Node.Instr.I64Const -> Node.Type.Value.I64
is Node.Instr.F32Const -> Node.Type.Value.F32
is Node.Instr.F64Const -> Node.Type.Value.F64
else -> error("unreachable")
}
data class InsnSet(
val startIndex: Int,
val insns: List<Node.Instr>,
val valueAddedToStack: Node.Type.Value?
)
data class Replacement(
val range: IntRange,
val preCallConsts: List<Node.Instr>
) {
// Subtract one because there is a call after this
val insnsSaved get() = (range.last + 1) - range.first - 1 - preCallConsts.size
fun overlaps(o: Replacement) = range.contains(o.range.first) || range.contains(o.range.last) ||
o.range.contains(range.first) || o.range.contains(range.last)
}
data class CommonPattern(
val newFunc: Node.Func,
// In order by earliest replacement first
val replacements: List<Replacement>
) {
// Replacement pieces saved (with one added for the invocation) less new func instructions
val insnsSaved get() = replacements.sumBy { it.insnsSaved } - newFunc.instructions.size
}
companion object : SplitLargeFunc()
}

View File

@ -4,9 +4,9 @@ import asmble.ast.Script
import asmble.compile.jvm.AstToAsm
import asmble.compile.jvm.ClsContext
import asmble.compile.jvm.withComputedFramesAndMaxs
import org.objectweb.asm.ClassWriter
import java.io.FileOutputStream
@Suppress("NAME_SHADOWING")
open class Compile : Command<Compile.Args>() {
override val name = "compile"
@ -32,6 +32,17 @@ open class Compile : Command<Compile.Args>() {
opt = "out",
desc = "The file name to output to. Can be '--' to write to stdout.",
default = "<outClass.class>"
),
name = bld.arg(
name = "name",
opt = "name",
desc = "The name to use for this module. Will override the name on the module if present.",
default = "<name on module or none>"
).takeIf { it != "<name on module or none>" },
includeBinary = bld.flag(
opt = "bindata",
desc = "Embed the WASM binary as an annotation on the class.",
lowPriority = true
)
).also { bld.done() }
@ -40,8 +51,8 @@ open class Compile : Command<Compile.Args>() {
val inFormat =
if (args.inFormat != "<use file extension>") args.inFormat
else args.inFile.substringAfterLast('.', "<unknown>")
val script = Translate.inToAst(args.inFile, inFormat)
val mod = (script.commands.firstOrNull() as? Script.Cmd.Module)?.module ?:
val script = Translate().also { it.logger = logger }.inToAst(args.inFile, inFormat)
val mod = (script.commands.firstOrNull() as? Script.Cmd.Module) ?:
error("Only a single sexpr for (module) allowed")
val outStream = when (args.outFile) {
"<outClass.class>" -> FileOutputStream(args.outClass.substringAfterLast('.') + ".class")
@ -52,8 +63,10 @@ open class Compile : Command<Compile.Args>() {
val ctx = ClsContext(
packageName = if (!args.outClass.contains('.')) "" else args.outClass.substringBeforeLast('.'),
className = args.outClass.substringAfterLast('.'),
mod = mod,
logger = logger
mod = mod.module,
modName = args.name ?: mod.name,
logger = logger,
includeBinary = args.includeBinary
)
AstToAsm.fromModule(ctx)
outStream.write(ctx.cls.withComputedFramesAndMaxs())
@ -64,7 +77,9 @@ open class Compile : Command<Compile.Args>() {
val inFile: String,
val inFormat: String,
val outClass: String,
val outFile: String
val outFile: String,
val name: String?,
val includeBinary: Boolean
)
companion object : Compile()

View File

@ -1,7 +1,11 @@
package asmble.cli
import asmble.compile.jvm.javaIdent
import asmble.run.jvm.Module
/**
* This class provide ''invoke'' WASM code functionality.
*/
open class Invoke : ScriptCommand<Invoke.Args>() {
override val name = "invoke"
@ -33,17 +37,19 @@ open class Invoke : ScriptCommand<Invoke.Args>() {
).also { bld.done() }
override fun run(args: Args) {
// Compiles wasm to bytecode, do registrations and so on.
val ctx = prepareContext(args.scriptArgs)
// Instantiate the module
val module =
if (args.module == "<last-in-entry>") ctx.modules.lastOrNull() ?: error("No modules available")
else ctx.registrations[args.module] ?: error("Unable to find module registered as ${args.module}")
else ctx.registrations[args.module] as? Module.Instance ?:
error("Unable to find module registered as ${args.module}")
// Just make sure the module is instantiated here...
module.instance(ctx)
val instance = module.instance(ctx)
// If an export is provided, call it
if (args.export != "<start-func>") args.export.javaIdent.let { javaName ->
val method = module.cls.declaredMethods.find { it.name == javaName } ?:
error("Unable to find export '${args.export}'")
// Finds java method(wasm fn) in class(wasm module) by name(declared in <start-func>)
val method = module.cls.declaredMethods.find { it.name == javaName } ?: error("Unable to find export '${args.export}'")
// Map args to params
require(method.parameterTypes.size == args.args.size) {
"Given arg count of ${args.args.size} is invalid for $method"
@ -57,11 +63,20 @@ open class Invoke : ScriptCommand<Invoke.Args>() {
else -> error("Unrecognized type for param ${index + 1}: $paramType")
}
}
val result = method.invoke(module.instance(ctx), *params.toTypedArray())
val result = method.invoke(instance, *params.toTypedArray())
if (args.resultToStdout && method.returnType != Void.TYPE) println(result)
}
}
/**
* Arguments for 'invoke' command.
*
* @param scriptArgs Common arguments for 'invoke' and 'run' ScriptCommands.
* @param module The module name to run. If it's a JVM class, it must have a no-arg constructor
* @param export The specific export function to invoke
* @param args Parameter for the export if export is present
* @param resultToStdout If true result will print to stout
*/
data class Args(
val scriptArgs: ScriptCommand.ScriptArgs,
val module: String,

View File

@ -0,0 +1,67 @@
package asmble.cli
import asmble.compile.jvm.Linker
import asmble.compile.jvm.withComputedFramesAndMaxs
import java.io.FileOutputStream
open class Link : Command<Link.Args>() {
override val name = "link"
override val desc = "Link WebAssembly modules in a single class file. TODO: not done"
override fun args(bld: Command.ArgsBuilder) = Args(
outFile = bld.arg(
name = "outFile",
opt = "out",
desc = "The file name to output to. Can be '--' to write to stdout.",
default = "<outClass.class>"
),
modules = bld.args(
name = "modules",
desc = "The fully qualified class name of the modules on the classpath to link. A module name can be" +
" added after an equals sign to set/override the existing module name."
),
outClass = bld.arg(
name = "outClass",
desc = "The fully qualified class name."
),
defaultMaxMem = bld.arg(
name = "defaultMaxMem",
opt = "maxmem",
desc = "The max number of pages to build memory with when not specified by the module/import.",
default = "10"
).toInt()
).also { bld.done() }
override fun run(args: Args) {
val outStream = when (args.outFile) {
"<outClass.class>" -> FileOutputStream(args.outClass.substringAfterLast('.') + ".class")
"--" -> System.out
else -> FileOutputStream(args.outFile)
}
outStream.use { outStream ->
val ctx = Linker.Context(
classes = args.modules.map { module ->
val pieces = module.split('=', limit = 2)
Linker.ModuleClass(
cls = Class.forName(pieces.first()),
overrideName = pieces.getOrNull(1)
)
},
className = args.outClass,
defaultMaxMemPages = args.defaultMaxMem
)
Linker.link(ctx)
outStream.write(ctx.cls.withComputedFramesAndMaxs())
}
}
data class Args(
val modules: List<String>,
val outClass: String,
val outFile: String,
val defaultMaxMem: Int
)
companion object : Link()
}

View File

@ -3,8 +3,11 @@ package asmble.cli
import asmble.util.Logger
import kotlin.system.exitProcess
val commands = listOf(Compile, Help, Invoke, Run, Translate)
val commands = listOf(Compile, Help, Invoke, Link, Run, SplitFunc, Translate)
/**
* Entry point of command line interface.
*/
fun main(args: Array<String>) {
if (args.isEmpty()) return println(
"""
@ -28,6 +31,7 @@ fun main(args: Array<String>) {
val globals = Main.globalArgs(argBuild)
logger = Logger.Print(globals.logLevel)
command.logger = logger
logger.info { "Running the command=${command.name} with args=${argBuild.args}" }
command.runWithArgs(argBuild)
} catch (e: Exception) {
logger.error { "Error ${command?.let { "in command '${it.name}'" } ?: ""}: ${e.message}" }

View File

@ -2,6 +2,7 @@ package asmble.cli
import asmble.ast.Script
import asmble.compile.jvm.javaIdent
import asmble.run.jvm.Module
import asmble.run.jvm.ScriptContext
import java.io.File
import java.util.*
@ -44,16 +45,21 @@ abstract class ScriptCommand<T> : Command<T>() {
)
fun prepareContext(args: ScriptArgs): ScriptContext {
var ctx = ScriptContext(
var context = ScriptContext(
packageName = "asmble.temp" + UUID.randomUUID().toString().replace("-", ""),
defaultMaxMemPages = args.defaultMaxMemPages
)
// Compile everything
ctx = args.inFiles.foldIndexed(ctx) { index, ctx, inFile ->
context = args.inFiles.foldIndexed(context) { index, ctx, inFile ->
try {
when (inFile.substringAfterLast('.')) {
// if input file is class file
"class" -> ctx.classLoader.addClass(File(inFile).readBytes()).let { ctx }
else -> Translate.inToAst(inFile, inFile.substringAfterLast('.')).let { inAst ->
// if input file is wasm file
else -> {
val translateCmd = Translate
translateCmd.logger = this.logger
translateCmd.inToAst(inFile, inFile.substringAfterLast('.')).let { inAst ->
val (mod, name) = (inAst.commands.singleOrNull() as? Script.Cmd.Module) ?:
error("Input file must only contain a single module")
val className = name?.javaIdent?.capitalize() ?:
@ -66,18 +72,29 @@ abstract class ScriptCommand<T> : Command<T>() {
}
}
}
} catch (e: Exception) { throw Exception("Failed loading $inFile - ${e.message}", e) }
}
} catch (e: Exception) {
throw Exception("Failed loading $inFile - ${e.message}", e)
}
}
// Do registrations
ctx = args.registrations.fold(ctx) { ctx, (moduleName, className) ->
val cls = Class.forName(className, true, ctx.classLoader)
ctx.copy(registrations = ctx.registrations +
(moduleName to ScriptContext.NativeModule(cls, cls.newInstance())))
context = args.registrations.fold(context) { ctx, (moduleName, className) ->
ctx.withModuleRegistered(moduleName,
Module.Native(Class.forName(className, true, ctx.classLoader).newInstance()))
}
if (args.specTestRegister) ctx = ctx.withHarnessRegistered()
return ctx
if (args.specTestRegister) context = context.withHarnessRegistered() // проверить что не так с "Cannot find compatible import for spectest::print"
return context
}
/**
* Common arguments for 'invoke' and 'run' ScriptCommands.
*
* @param inFiles Files to add to classpath. Can be wasm, wast, or class file
* @param registrations Register class name to a module name
* @param disableAutoRegister If set, this will not auto-register modules with names
* @param specTestRegister If true, registers the spec test harness as 'spectest'
* @param defaultMaxMemPages The maximum number of memory pages when a module doesn't say
*/
data class ScriptArgs(
val inFiles: List<String>,
val registrations: List<Pair<String, String>>,

View File

@ -0,0 +1,146 @@
package asmble.cli
import asmble.ast.Node
import asmble.ast.Script
import asmble.ast.opt.SplitLargeFunc
open class SplitFunc : Command<SplitFunc.Args>() {
override val name = "split-func"
override val desc = "Split a WebAssembly function into two"
override fun args(bld: Command.ArgsBuilder) = Args(
inFile = bld.arg(
name = "inFile",
desc = "The wast or wasm WebAssembly file name. Can be '--' to read from stdin."
),
funcName = bld.arg(
name = "funcName",
desc = "The name (or '#' + function space index) of the function to split"
),
inFormat = bld.arg(
name = "inFormat",
opt = "in",
desc = "Either 'wast' or 'wasm' to describe format.",
default = "<use file extension>",
lowPriority = true
),
outFile = bld.arg(
name = "outFile",
opt = "outFile",
desc = "The wast or wasm WebAssembly file name. Can be '--' to write to stdout.",
default = "<inFileSansExt.split.wasm or stdout>",
lowPriority = true
),
outFormat = bld.arg(
name = "outFormat",
opt = "out",
desc = "Either 'wast' or 'wasm' to describe format.",
default = "<use file extension or wast for stdout>",
lowPriority = true
),
compact = bld.flag(
opt = "compact",
desc = "If set for wast out format, will be compacted.",
lowPriority = true
),
minInsnSetLength = bld.arg(
name = "minInsnSetLength",
opt = "minLen",
desc = "The minimum number of instructions allowed for the split off function.",
default = "5",
lowPriority = true
).toInt(),
maxInsnSetLength = bld.arg(
name = "maxInsnSetLength",
opt = "maxLen",
desc = "The maximum number of instructions allowed for the split off function.",
default = "40",
lowPriority = true
).toInt(),
maxNewFuncParamCount = bld.arg(
name = "maxNewFuncParamCount",
opt = "maxParams",
desc = "The maximum number of params allowed for the split off function.",
default = "30",
lowPriority = true
).toInt(),
attempts = bld.arg(
name = "attempts",
opt = "attempts",
desc = "The number of attempts to perform.",
default = "1",
lowPriority = true
).toInt()
).also { bld.done() }
override fun run(args: Args) {
// Load the mod
val translate = Translate().also { it.logger = logger }
val inFormat =
if (args.inFormat != "<use file extension>") args.inFormat
else args.inFile.substringAfterLast('.', "<unknown>")
val script = translate.inToAst(args.inFile, inFormat)
var mod = (script.commands.firstOrNull() as? Script.Cmd.Module)?.module ?: error("Only a single module allowed")
// Do attempts
val splitter = SplitLargeFunc(
minSetLength = args.minInsnSetLength,
maxSetLength = args.maxInsnSetLength,
maxParamCount = args.maxNewFuncParamCount
)
for (attempt in 0 until args.attempts) {
// Find the function
var index = mod.names?.funcNames?.toList()?.find { it.second == args.funcName }?.first
if (index == null && args.funcName.startsWith('#')) index = args.funcName.drop(1).toInt()
val origFunc = index?.let {
mod.funcs.getOrNull(it - mod.imports.count { it.kind is Node.Import.Kind.Func })
} ?: error("Unable to find func")
// Split it
val results = splitter.apply(mod, index)
if (results == null) {
logger.warn { "No instructions after attempt $attempt" }
break
}
val (splitMod, insnsSaved) = results
val newFunc = splitMod.funcs[index - mod.imports.count { it.kind is Node.Import.Kind.Func }]
val splitFunc = splitMod.funcs.last()
logger.warn {
"Split complete, from func with ${origFunc.instructions.size} insns to a func " +
"with ${newFunc.instructions.size} insns + delegated func " +
"with ${splitFunc.instructions.size} insns and ${splitFunc.type.params.size} params, " +
"saved $insnsSaved insns"
}
mod = splitMod
}
// Write it
val outFile = when {
args.outFile != "<inFileSansExt.split.wasm or stdout>" -> args.outFile
args.inFile == "--" -> "--"
else -> args.inFile.replaceAfterLast('.', "split." + args.inFile.substringAfterLast('.'))
}
val outFormat = when {
args.outFormat != "<use file extension or wast for stdout>" -> args.outFormat
outFile == "--" -> "wast"
else -> outFile.substringAfterLast('.', "<unknown>")
}
translate.astToOut(outFile, outFormat, args.compact,
Script(listOf(Script.Cmd.Module(mod, mod.names?.moduleName))))
}
data class Args(
val inFile: String,
val inFormat: String,
val funcName: String,
val outFile: String,
val outFormat: String,
val compact: Boolean,
val minInsnSetLength: Int,
val maxInsnSetLength: Int,
val maxNewFuncParamCount: Int,
val attempts: Int
)
companion object : SplitFunc()
}

View File

@ -52,13 +52,38 @@ open class Translate : Command<Translate.Args>() {
if (args.outFormat != "<use file extension or wast for stdout>") args.outFormat
else if (args.outFile == "--") "wast"
else args.outFile.substringAfterLast('.', "<unknown>")
astToOut(args.outFile, outFormat, args.compact, script)
}
fun inToAst(inFile: String, inFormat: String): Script {
val inBytes =
if (inFile == "--") System.`in`.use { it.readBytes() }
else File(inFile).let { f ->
// Input file might not fit into the memory
FileInputStream(f).use { it.readBytes(f.length().toIntExact()) }
}
return when (inFormat) {
"wast" -> StrToSExpr.parse(inBytes.toString(Charsets.UTF_8)).let { res ->
when (res) {
is StrToSExpr.ParseResult.Error -> error("Error [${res.pos.line}:${res.pos.char}] - ${res.pos}")
is StrToSExpr.ParseResult.Success -> SExprToAst.toScript(SExpr.Multi(res.vals))
}
}
"wasm" ->
Script(listOf(Script.Cmd.Module(BinaryToAst(logger = this.logger).toModule(
ByteReader.InputStream(inBytes.inputStream())), null)))
else -> error("Unknown in format '$inFormat'")
}
}
fun astToOut(outFile: String, outFormat: String, compact: Boolean, script: Script) {
val outStream =
if (args.outFile == "--") System.out
else FileOutputStream(args.outFile)
if (outFile == "--") System.out
else FileOutputStream(outFile)
outStream.use { outStream ->
when (outFormat) {
"wast" -> {
val sexprToStr = if (args.compact) SExprToStr.Compact else SExprToStr
val sexprToStr = if (compact) SExprToStr.Compact else SExprToStr
val sexprs = AstToSExpr.fromScript(script)
outStream.write(sexprToStr.fromSExpr(*sexprs.toTypedArray()).toByteArray())
}
@ -72,24 +97,6 @@ open class Translate : Command<Translate.Args>() {
}
}
fun inToAst(inFile: String, inFormat: String): Script {
val inBytes =
if (inFile == "--") System.`in`.use { it.readBytes() }
else File(inFile).let { f -> FileInputStream(f).use { it.readBytes(f.length().toIntExact()) } }
return when (inFormat) {
"wast" -> StrToSExpr.parse(inBytes.toString(Charsets.UTF_8)).let { res ->
when (res) {
is StrToSExpr.ParseResult.Error -> error("Error [${res.pos.line}:${res.pos.char}] - ${res.pos}")
is StrToSExpr.ParseResult.Success -> SExprToAst.toScript(SExpr.Multi(res.vals))
}
}
"wasm" ->
Script(listOf(Script.Cmd.Module(BinaryToAst.toModule(
ByteReader.InputStream(inBytes.inputStream())), null)))
else -> error("Unknown in format '$inFormat'")
}
}
data class Args(
val inFile: String,
val inFormat: String,

View File

@ -9,6 +9,9 @@ import org.objectweb.asm.tree.*
import org.objectweb.asm.util.TraceClassVisitor
import java.io.PrintWriter
import java.io.StringWriter
import java.lang.reflect.Constructor
import java.lang.reflect.Executable
import java.lang.reflect.Method
import kotlin.reflect.KClass
import kotlin.reflect.KFunction
import kotlin.reflect.KProperty
@ -25,6 +28,7 @@ fun KFunction<*>.invokeStatic() =
fun KFunction<*>.invokeVirtual() =
MethodInsnNode(Opcodes.INVOKEVIRTUAL, this.declarer.ref.asmName, this.name, this.asmDesc, false)
@Suppress("NOTHING_TO_INLINE")
inline fun <T : Function<*>> forceFnType(fn: T) = fn as KFunction<*>
val KClass<*>.const: LdcInsnNode get() = (if (this == Void::class) Void.TYPE else this.java).const
@ -60,6 +64,13 @@ val Class<*>.valueType: Node.Type.Value? get() = when (this) {
else -> error("Unrecognized value type class: $this")
}
val Executable.ref: TypeRef get() = when (this) {
is Method -> TypeRef(Type.getType(this))
is Constructor<*> -> TypeRef(Type.getType(this))
else -> error("Unknown executable $this")
}
val KProperty<*>.declarer: Class<*> get() = this.javaField!!.declaringClass
val KProperty<*>.asmDesc: String get() = Type.getDescriptor(this.javaField!!.type)
@ -178,10 +189,12 @@ fun MethodNode.toAsmString(): String {
val Node.Type.Func.asmDesc: String get() =
(this.ret?.typeRef ?: Void::class.ref).asMethodRetDesc(*this.params.map { it.typeRef }.toTypedArray())
fun ClassNode.withComputedFramesAndMaxs(): ByteArray {
// TODO: compute maxs adds a bunch of NOPs for unreachable code
// See $func12 of block.wast. Is removing these worth the extra visit cycle?
val cw = ClassWriter(ClassWriter.COMPUTE_FRAMES + ClassWriter.COMPUTE_MAXS)
fun ClassNode.withComputedFramesAndMaxs(
cw: ClassWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES + ClassWriter.COMPUTE_MAXS)
): ByteArray {
// Note, compute maxs adds a bunch of NOPs for unreachable code.
// See $func12 of block.wast. I don't believe the extra time over the
// instructions to remove the NOPs is worth it.
this.accept(cw)
return cw.toByteArray()
}
@ -197,3 +210,7 @@ fun ByteArray.asClassNode(): ClassNode {
ClassReader(this).accept(newNode, 0)
return newNode
}
fun ByteArray.chunked(v: Int) = (0 until size step v).asSequence().map {
copyOfRange(it, (it + v).takeIf { it < size } ?: size)
}

View File

@ -1,10 +1,17 @@
package asmble.compile.jvm
import asmble.annotation.WasmExport
import asmble.annotation.WasmExternalKind
import asmble.annotation.WasmImport
import asmble.annotation.WasmModule
import asmble.ast.Node
import asmble.io.AstToBinary
import asmble.io.ByteWriter
import asmble.util.Either
import org.objectweb.asm.Opcodes
import org.objectweb.asm.Type
import org.objectweb.asm.tree.*
import java.io.ByteArrayOutputStream
import java.lang.invoke.MethodHandle
import java.lang.invoke.MethodHandles
@ -13,19 +20,21 @@ open class AstToAsm {
fun fromModule(ctx: ClsContext) {
// Invoke dynamic among other things
ctx.cls.superName = Object::class.ref.asmName
ctx.cls.version = Opcodes.V1_7
ctx.cls.version = Opcodes.V1_8
ctx.cls.access += Opcodes.ACC_PUBLIC
addFields(ctx)
addConstructors(ctx)
addFuncs(ctx)
addExports(ctx)
addAnnotations(ctx)
}
fun addFields(ctx: ClsContext) {
// Mem field if present
// Mem field if present, adds `private final field memory` to
if (ctx.hasMemory)
ctx.cls.fields.add(FieldNode(Opcodes.ACC_PRIVATE + Opcodes.ACC_FINAL, "memory",
ctx.mem.memType.asmDesc, null, null))
ctx.cls.fields.add(
FieldNode((Opcodes.ACC_PRIVATE + Opcodes.ACC_FINAL), "memory", ctx.mem.memType.asmDesc, null, null)
)
// Table field if present...
// Private final for now, but likely won't be final in future versions supporting
// mutable tables, may be not even a table but a list (and final)
@ -33,17 +42,19 @@ open class AstToAsm {
ctx.cls.fields.add(FieldNode(Opcodes.ACC_PRIVATE + Opcodes.ACC_FINAL, "table",
Array<MethodHandle>::class.ref.asmDesc, null, null))
// Now all method imports as method handles
// TODO: why does this fail with asm-debug-all but not with just regular asm?
ctx.cls.fields.addAll(ctx.importFuncs.indices.map {
FieldNode(Opcodes.ACC_PRIVATE + Opcodes.ACC_FINAL, ctx.funcName(it),
MethodHandle::class.ref.asmDesc, null, null)
})
// Now all import globals as getter (and maybe setter) method handles
ctx.cls.fields.addAll(ctx.importGlobals.mapIndexed { index, import ->
if ((import.kind as Node.Import.Kind.Global).type.mutable) throw CompileErr.MutableGlobalImport(index)
FieldNode(Opcodes.ACC_PRIVATE + Opcodes.ACC_FINAL, ctx.importGlobalGetterFieldName(index),
val getter = FieldNode(Opcodes.ACC_PRIVATE + Opcodes.ACC_FINAL, ctx.importGlobalGetterFieldName(index),
MethodHandle::class.ref.asmDesc, null, null)
})
if (!(import.kind as Node.Import.Kind.Global).type.mutable) listOf(getter)
else listOf(getter, FieldNode(
Opcodes.ACC_PRIVATE + Opcodes.ACC_FINAL, ctx.importGlobalSetterFieldName(index),
MethodHandle::class.ref.asmDesc, null, null))
}.flatten())
// Now all non-import globals
ctx.cls.fields.addAll(ctx.mod.globals.mapIndexed { index, global ->
val access = Opcodes.ACC_PRIVATE + if (!global.type.mutable) Opcodes.ACC_FINAL else 0
@ -86,7 +97,7 @@ open class AstToAsm {
func = initializeConstructorTables(ctx, func, 0)
func = executeConstructorStartFunction(ctx, func, 0)
func = func.addInsns(InsnNode(Opcodes.RETURN))
ctx.cls.methods.add(func.toMethodNode())
ctx.cls.methods.add(toConstructorNode(ctx, func))
}
fun addMaxMemConstructor(ctx: ClsContext) {
@ -108,7 +119,7 @@ open class AstToAsm {
MethodInsnNode(Opcodes.INVOKESPECIAL, ctx.thisRef.asmName, "<init>", desc, false),
InsnNode(Opcodes.RETURN)
)
ctx.cls.methods.add(func.toMethodNode())
ctx.cls.methods.add(toConstructorNode(ctx, func))
}
fun addMemClassConstructor(ctx: ClsContext) {
@ -149,7 +160,7 @@ open class AstToAsm {
func = executeConstructorStartFunction(ctx, func, 1)
func = func.addInsns(InsnNode(Opcodes.RETURN))
ctx.cls.methods.add(func.toMethodNode())
ctx.cls.methods.add(toConstructorNode(ctx, func))
}
fun addMemDefaultConstructor(ctx: ClsContext) {
@ -168,24 +179,100 @@ open class AstToAsm {
MethodInsnNode(Opcodes.INVOKESPECIAL, ctx.thisRef.asmName, "<init>", desc, false),
InsnNode(Opcodes.RETURN)
)
ctx.cls.methods.add(func.toMethodNode())
ctx.cls.methods.add(toConstructorNode(ctx, func))
}
fun constructorImportTypes(ctx: ClsContext) =
ctx.importFuncs.map { MethodHandle::class.ref } +
// We know it's only getters
ctx.importGlobals.map { MethodHandle::class.ref } +
ctx.mod.imports.filter { it.kind is Node.Import.Kind.Table }.map { Array<MethodHandle>::class.ref }
ctx.importGlobals.flatMap {
// If it's mutable, it also comes with a setter
if ((it.kind as? Node.Import.Kind.Global)?.type?.mutable == false) listOf(MethodHandle::class.ref)
else listOf(MethodHandle::class.ref, MethodHandle::class.ref)
} + ctx.mod.imports.filter { it.kind is Node.Import.Kind.Table }.map { Array<MethodHandle>::class.ref }
fun toConstructorNode(ctx: ClsContext, func: Func) = mutableListOf<List<AnnotationNode>>().let { paramAnns ->
// If the first param is a mem class and imported, add annotation
// Otherwise if it is a mem class and not-imported or an int, no annotations
// Otherwise do nothing because the rest of the params are imports
func.params.firstOrNull()?.also { firstParam ->
if (firstParam == Int::class.ref) {
paramAnns.add(emptyList())
} else if (firstParam == ctx.mem.memType) {
val importMem = ctx.mod.imports.find { it.kind is Node.Import.Kind.Memory }
if (importMem == null) paramAnns.add(emptyList())
else paramAnns.add(listOf(importAnnotation(ctx, importMem)))
}
}
// All non-mem imports one after another
ctx.importFuncs.forEach { paramAnns.add(listOf(importAnnotation(ctx, it))) }
ctx.importGlobals.forEach {
paramAnns.add(listOf(importAnnotation(ctx, it)))
// There are two annotations here if it's mutable
if ((it.kind as? Node.Import.Kind.Global)?.type?.mutable == true)
paramAnns.add(listOf(importAnnotation(ctx, it).also {
it.values.add("globalSetter")
it.values.add(true)
}))
}
ctx.mod.imports.forEach {
if (it.kind is Node.Import.Kind.Table) paramAnns.add(listOf(importAnnotation(ctx, it)))
}
func.toMethodNode().also { it.visibleParameterAnnotations = paramAnns.toTypedArray() }
}
fun importAnnotation(ctx: ClsContext, import: Node.Import) = AnnotationNode(WasmImport::class.ref.asmDesc).also {
it.values = mutableListOf<Any>("module", import.module, "field", import.field)
fun addValues(desc: String, limits: Node.ResizableLimits? = null) {
it.values.add("desc")
it.values.add(desc)
if (limits != null) {
it.values.add("resizableLimitInitial")
it.values.add(limits.initial)
if (limits.maximum != null) {
it.values.add("resizableLimitMaximum")
it.values.add(limits.maximum)
}
}
it.values.add("kind")
it.values.add(arrayOf(WasmExternalKind::class.ref.asmDesc, when (import.kind) {
is Node.Import.Kind.Func -> WasmExternalKind.FUNCTION.name
is Node.Import.Kind.Table -> WasmExternalKind.TABLE.name
is Node.Import.Kind.Memory -> WasmExternalKind.MEMORY.name
is Node.Import.Kind.Global -> WasmExternalKind.GLOBAL.name
}))
}
when (import.kind) {
is Node.Import.Kind.Func ->
ctx.typeAtIndex(import.kind.typeIndex).let { addValues(it.asmDesc) }
is Node.Import.Kind.Table ->
addValues(Array<MethodHandle>::class.ref.asMethodRetDesc(), import.kind.type.limits)
is Node.Import.Kind.Memory ->
addValues(ctx.mem.memType.asMethodRetDesc(), import.kind.type.limits)
is Node.Import.Kind.Global ->
addValues(import.kind.type.contentType.typeRef.asMethodRetDesc())
}
}
fun setConstructorGlobalImports(ctx: ClsContext, func: Func, paramsBeforeImports: Int) =
ctx.importGlobals.indices.fold(func) { func, importIndex ->
ctx.importGlobals.foldIndexed(func to ctx.importFuncs.size + paramsBeforeImports) {
importIndex, (func, importParamOffset), import ->
// Always a getter handle
func.addInsns(
VarInsnNode(Opcodes.ALOAD, 0),
VarInsnNode(Opcodes.ALOAD, ctx.importFuncs.size + importIndex + paramsBeforeImports + 1),
VarInsnNode(Opcodes.ALOAD, importParamOffset + 1),
FieldInsnNode(Opcodes.PUTFIELD, ctx.thisRef.asmName,
ctx.importGlobalGetterFieldName(importIndex), MethodHandle::class.ref.asmDesc)
)
).let { func ->
// If it's mutable, it has a second setter handle
if ((import.kind as? Node.Import.Kind.Global)?.type?.mutable == false) func to importParamOffset + 1
else func.addInsns(
VarInsnNode(Opcodes.ALOAD, 0),
VarInsnNode(Opcodes.ALOAD, importParamOffset + 2),
FieldInsnNode(Opcodes.PUTFIELD, ctx.thisRef.asmName,
ctx.importGlobalSetterFieldName(importIndex), MethodHandle::class.ref.asmDesc)
) to importParamOffset + 2
}
}.first
fun setConstructorFunctionImports(ctx: ClsContext, func: Func, paramsBeforeImports: Int) =
ctx.importFuncs.indices.fold(func) { func, importIndex ->
@ -199,7 +286,10 @@ open class AstToAsm {
fun setConstructorTableImports(ctx: ClsContext, func: Func, paramsBeforeImports: Int) =
if (ctx.mod.imports.none { it.kind is Node.Import.Kind.Table }) func else {
val importIndex = ctx.importFuncs.size + ctx.importGlobals.size + paramsBeforeImports + 1
val importIndex = ctx.importFuncs.size +
// Mutable global imports have setters and take up two spots
ctx.importGlobals.sumBy { if ((it.kind as? Node.Import.Kind.Global)?.type?.mutable == true) 2 else 1 } +
paramsBeforeImports + 1
func.addInsns(
VarInsnNode(Opcodes.ALOAD, 0),
VarInsnNode(Opcodes.ALOAD, importIndex),
@ -211,7 +301,7 @@ open class AstToAsm {
fun initializeConstructorGlobals(ctx: ClsContext, func: Func, paramsBeforeImports: Int) =
ctx.mod.globals.foldIndexed(func) { index, func, global ->
// In the MVP, we can trust the init is constant stuff and a single instr
if (global.init.size > 1) throw CompileErr.GlobalInitNotConstant(index)
if (global.init.size != 1) throw CompileErr.GlobalInitNotConstant(index)
func.addInsns(VarInsnNode(Opcodes.ALOAD, 0)).
addInsns(
global.init.firstOrNull().let {
@ -237,11 +327,14 @@ open class AstToAsm {
global.type.contentType.typeRef,
refGlobalKind.type.contentType.typeRef
)
val paramOffset = ctx.importFuncs.size + paramsBeforeImports + 1 +
ctx.importGlobals.take(it.index).sumBy {
// Immutable jumps 1, mutable jumps 2
if ((it.kind as? Node.Import.Kind.Global)?.type?.mutable == false) 1
else 2
}
listOf(
VarInsnNode(
Opcodes.ALOAD,
ctx.importFuncs.size + it.index + paramsBeforeImports + 1
),
VarInsnNode(Opcodes.ALOAD, paramOffset),
MethodInsnNode(
Opcodes.INVOKEVIRTUAL,
MethodHandle::class.ref.asmName,
@ -294,7 +387,10 @@ open class AstToAsm {
// Otherwise, it was imported and we can set the elems on the imported one
// from the parameter
// TODO: I think this is a security concern and bad practice, may revisit
val importIndex = ctx.importFuncs.size + ctx.importGlobals.size + paramsBeforeImports + 1
val importIndex = ctx.importFuncs.size + ctx.importGlobals.sumBy {
// Immutable is 1, mutable is 2
if ((it.kind as? Node.Import.Kind.Global)?.type?.mutable == false) 1 else 2
} + paramsBeforeImports + 1
return func.addInsns(VarInsnNode(Opcodes.ALOAD, importIndex)).
let { func -> addElemsToTable(ctx, func, paramsBeforeImports) }.
// Remove the array that's still there
@ -412,6 +508,13 @@ open class AstToAsm {
}
}
fun exportAnnotation(export: Node.Export) = AnnotationNode(WasmExport::class.ref.asmDesc).also {
it.values = listOf(
"value", export.field,
"kind", arrayOf(WasmExternalKind::class.ref.asmDesc, export.kind.name)
)
}
fun addExportFunc(ctx: ClsContext, export: Node.Export) {
val funcType = ctx.funcTypeAtIndex(export.index)
val method = MethodNode(Opcodes.ACC_PUBLIC, export.field.javaIdent, funcType.asmDesc, null, null)
@ -453,6 +556,7 @@ open class AstToAsm {
Node.Type.Value.F32 -> Opcodes.FRETURN
Node.Type.Value.F64 -> Opcodes.DRETURN
}))
method.visibleAnnotations = listOf(exportAnnotation(export))
ctx.cls.methods.plusAssign(method)
}
@ -462,27 +566,58 @@ open class AstToAsm {
is Either.Left -> (global.v.kind as Node.Import.Kind.Global).type
is Either.Right -> global.v.type
}
if (type.mutable) throw CompileErr.MutableGlobalExport(export.index)
// Create a simple getter
val method = MethodNode(Opcodes.ACC_PUBLIC, "get" + export.field.javaIdent.capitalize(),
val getter = MethodNode(Opcodes.ACC_PUBLIC, "get" + export.field.javaIdent.capitalize(),
"()" + type.contentType.typeRef.asmDesc, null, null)
method.addInsns(VarInsnNode(Opcodes.ALOAD, 0))
if (global is Either.Left) method.addInsns(
getter.addInsns(VarInsnNode(Opcodes.ALOAD, 0))
if (global is Either.Left) getter.addInsns(
FieldInsnNode(Opcodes.GETFIELD, ctx.thisRef.asmName,
ctx.importGlobalGetterFieldName(export.index), MethodHandle::class.ref.asmDesc),
MethodInsnNode(Opcodes.INVOKEVIRTUAL, MethodHandle::class.ref.asmName, "invokeExact",
"()" + type.contentType.typeRef.asmDesc, false)
) else method.addInsns(
) else getter.addInsns(
FieldInsnNode(Opcodes.GETFIELD, ctx.thisRef.asmName, ctx.globalName(export.index),
type.contentType.typeRef.asmDesc)
)
method.addInsns(InsnNode(when (type.contentType) {
getter.addInsns(InsnNode(when (type.contentType) {
Node.Type.Value.I32 -> Opcodes.IRETURN
Node.Type.Value.I64 -> Opcodes.LRETURN
Node.Type.Value.F32 -> Opcodes.FRETURN
Node.Type.Value.F64 -> Opcodes.DRETURN
}))
ctx.cls.methods.plusAssign(method)
getter.visibleAnnotations = listOf(exportAnnotation(export))
ctx.cls.methods.plusAssign(getter)
// If mutable, create simple setter
if (type.mutable) {
val setter = MethodNode(Opcodes.ACC_PUBLIC, "set" + export.field.javaIdent.capitalize(),
"(${type.contentType.typeRef.asmDesc})V", null, null)
setter.addInsns(VarInsnNode(Opcodes.ALOAD, 0))
if (global is Either.Left) setter.addInsns(
FieldInsnNode(Opcodes.GETFIELD, ctx.thisRef.asmName,
ctx.importGlobalSetterFieldName(export.index), MethodHandle::class.ref.asmDesc),
VarInsnNode(when (type.contentType) {
Node.Type.Value.I32 -> Opcodes.ILOAD
Node.Type.Value.I64 -> Opcodes.LLOAD
Node.Type.Value.F32 -> Opcodes.FLOAD
Node.Type.Value.F64 -> Opcodes.DLOAD
}, 1),
MethodInsnNode(Opcodes.INVOKEVIRTUAL, MethodHandle::class.ref.asmName, "invokeExact",
"(${type.contentType.typeRef.asmDesc})V", false),
InsnNode(Opcodes.RETURN)
) else setter.addInsns(
VarInsnNode(when (type.contentType) {
Node.Type.Value.I32 -> Opcodes.ILOAD
Node.Type.Value.I64 -> Opcodes.LLOAD
Node.Type.Value.F32 -> Opcodes.FLOAD
Node.Type.Value.F64 -> Opcodes.DLOAD
}, 1),
FieldInsnNode(Opcodes.PUTFIELD, ctx.thisRef.asmName, ctx.globalName(export.index),
type.contentType.typeRef.asmDesc),
InsnNode(Opcodes.RETURN)
)
setter.visibleAnnotations = listOf(exportAnnotation(export))
ctx.cls.methods.plusAssign(setter)
}
}
fun addExportMemory(ctx: ClsContext, export: Node.Export) {
@ -495,6 +630,7 @@ open class AstToAsm {
FieldInsnNode(Opcodes.GETFIELD, ctx.thisRef.asmName, "memory", ctx.mem.memType.asmDesc),
InsnNode(Opcodes.ARETURN)
)
method.visibleAnnotations = listOf(exportAnnotation(export))
ctx.cls.methods.plusAssign(method)
}
@ -508,6 +644,7 @@ open class AstToAsm {
FieldInsnNode(Opcodes.GETFIELD, ctx.thisRef.asmName, "table", Array<MethodHandle>::class.ref.asmDesc),
InsnNode(Opcodes.ARETURN)
)
method.visibleAnnotations = listOf(exportAnnotation(export))
ctx.cls.methods.plusAssign(method)
}
@ -517,5 +654,22 @@ open class AstToAsm {
})
}
fun addAnnotations(ctx: ClsContext) {
val annotationVals = mutableListOf<Any>()
ctx.modName?.let { annotationVals.addAll(listOf("name", it)) }
if (ctx.includeBinary) {
// We are going to store this as a string of bytes in an annotation on the class. The linker
// used to use this, but no longer does so it is opt-in for others to use. We choose to use an
// annotation instead of an attribute for the same reasons Scala chose to make the switch in
// 2.8+: Easier runtime reflection despite some size cost.
annotationVals.addAll(listOf("binary", ByteArrayOutputStream().also {
ByteWriter.OutputStream(it).also { AstToBinary.fromModule(it, ctx.mod) }
}.toByteArray().toString(Charsets.ISO_8859_1)))
}
ctx.cls.visibleAnnotations = listOf(
AnnotationNode(WasmModule::class.ref.asmDesc).also { it.values = annotationVals }
)
}
companion object : AstToAsm()
}

View File

@ -46,19 +46,25 @@ open class ByteBufferMem(val direct: Boolean = true) : Mem {
let(buildOffset).popExpecting(Int::class.ref).
addInsns(
forceFnType<ByteBuffer.(Int) -> Buffer>(ByteBuffer::position).invokeVirtual(),
TypeInsnNode(Opcodes.CHECKCAST, memType.asmName),
// TODO: Is there a cheaper bulk approach instead of manually building
// a byte array? What's the harm of using a String in the constant pool instead?
bytes.size.const,
IntInsnNode(Opcodes.NEWARRAY, Opcodes.T_BYTE)
).
addInsns(bytes.withIndex().flatMap { (index, byte) ->
listOf(InsnNode(Opcodes.DUP), index.const, byte.toInt().const, InsnNode(Opcodes.BASTORE))
}).
addInsns(
TypeInsnNode(Opcodes.CHECKCAST, memType.asmName)
).addInsns(
// We're going to do this as an LDC string in ISO-8859 and read it back at runtime. However,
// due to JVM limits, we can't have a string > 65536 chars. We chunk into 16300 because when
// converting to UTF8 const it can be up to 4 bytes per char, so this makes sure it doesn't
// overflow.
bytes.chunked(16300).flatMap { bytes ->
sequenceOf(
LdcInsnNode(bytes.toString(Charsets.ISO_8859_1)),
LdcInsnNode("ISO-8859-1"),
// Ug, can't do func refs on native types here...
MethodInsnNode(Opcodes.INVOKEVIRTUAL, String::class.ref.asmName,
"getBytes", "(Ljava/lang/String;)[B", false),
0.const,
bytes.size.const,
forceFnType<ByteBuffer.(ByteArray, Int, Int) -> ByteBuffer>(ByteBuffer::put).invokeVirtual(),
forceFnType<ByteBuffer.(ByteArray, Int, Int) -> ByteBuffer>(ByteBuffer::put).invokeVirtual()
)
}.toList()
).addInsns(
InsnNode(Opcodes.POP)
)
@ -252,5 +258,7 @@ open class ByteBufferMem(val direct: Boolean = true) : Mem {
}
}
override val storeLeavesMemOnStack get() = true
companion object : ByteBufferMem()
}

View File

@ -16,6 +16,7 @@ data class ClsContext(
val mod: Node.Module,
val cls: ClassNode = ClassNode().also { it.name = (packageName.replace('.', '/') + "/$className").trimStart('/') },
val mem: Mem = ByteBufferMem,
val modName: String? = null,
val reworker: InsnReworker = InsnReworker,
val logger: Logger = Logger.Print(Logger.Level.OFF),
val funcBuilder: FuncBuilder = FuncBuilder,
@ -26,7 +27,8 @@ data class ClsContext(
val preventMemIndexOverflow: Boolean = false,
val accurateNanBits: Boolean = true,
val checkSignedDivIntegerOverflow: Boolean = true,
val jumpTableChunkSize: Int = 5000
val jumpTableChunkSize: Int = 5000,
val includeBinary: Boolean = false
) : Logger by logger {
val importFuncs: List<Node.Import> by lazy { mod.imports.filter { it.kind is Node.Import.Kind.Func } }
val importGlobals: List<Node.Import> by lazy { mod.imports.filter { it.kind is Node.Import.Kind.Global } }
@ -37,6 +39,24 @@ data class ClsContext(
val hasTable: Boolean by lazy {
mod.tables.isNotEmpty() || mod.imports.any { it.kind is Node.Import.Kind.Table }
}
val dedupedFuncNames: Map<Int, String>? by lazy {
// Consider all exports as seen
val seen = mod.exports.flatMap { export ->
when {
export.kind == Node.ExternalKind.FUNCTION -> listOf(export.field.javaIdent)
// Just to make it easy, consider all globals as having setters
export.kind == Node.ExternalKind.GLOBAL ->
export.field.javaIdent.capitalize().let { listOf("get$it", "set$it") }
else -> listOf("get" + export.field.javaIdent.capitalize())
}
}.toMutableSet()
mod.names?.funcNames?.toList()?.sortedBy { it.first }?.map { (index, origName) ->
var name = origName.javaIdent
var nameIndex = 0
while (!seen.add(name)) name = origName.javaIdent + (nameIndex++)
index to name
}?.toMap()
}
fun assertHasMemory() { if (!hasMemory) throw CompileErr.UnknownMemory(0) }
@ -69,7 +89,7 @@ data class ClsContext(
fun importGlobalGetterFieldName(index: Int) = "import\$get" + globalName(index)
fun importGlobalSetterFieldName(index: Int) = "import\$set" + globalName(index)
fun globalName(index: Int) = "\$global$index"
fun funcName(index: Int) = "\$func$index"
fun funcName(index: Int) = dedupedFuncNames?.get(index) ?: "\$func$index"
private fun syntheticFunc(
nameSuffix: String,

View File

@ -10,6 +10,7 @@ sealed class CompileErr(message: String, cause: Throwable? = null) : RuntimeExce
val actual: TypeRef?
) : CompileErr("Expected any type of ${Arrays.toString(expected)}, got $actual") {
override val asmErrString get() = "type mismatch"
override val asmErrStrings get() = listOf(asmErrString, "mismatching label")
}
class StackInjectionMismatch(
@ -48,7 +49,7 @@ sealed class CompileErr(message: String, cause: Throwable? = null) : RuntimeExce
override val asmErrString get() = "type mismatch"
}
class IfThenValueWithoutElse() : CompileErr("If has value but no else clause") {
class IfThenValueWithoutElse : CompileErr("If has value but no else clause") {
override val asmErrString get() = "type mismatch"
}
@ -83,11 +84,12 @@ sealed class CompileErr(message: String, cause: Throwable? = null) : RuntimeExce
}
class UnknownMemory(val index: Int) : CompileErr("No memory present at index $index") {
override val asmErrString get() = "unknown memory"
override val asmErrString get() = "unknown memory $index"
}
class UnknownTable(val index: Int) : CompileErr("No table present at index $index") {
override val asmErrString get() = "unknown table"
override val asmErrStrings get() = listOf(asmErrString, "unknown table $index")
}
class UnknownType(val index: Int) : CompileErr("No type present for index $index") {
@ -100,22 +102,11 @@ sealed class CompileErr(message: String, cause: Throwable? = null) : RuntimeExce
override val asmErrString get() = "global is immutable"
}
class MutableGlobalImport(
val index: Int
) : CompileErr("Attempted to import mutable global at index $index") {
override val asmErrString get() = "mutable globals cannot be imported"
}
class MutableGlobalExport(
val index: Int
) : CompileErr("Attempted to export global $index which is mutable") {
override val asmErrString get() = "mutable globals cannot be exported"
}
class GlobalInitNotConstant(
val index: Int
) : CompileErr("Expected init for global $index to be constant") {
) : CompileErr("Expected init for global $index to be single constant value") {
override val asmErrString get() = "constant expression required"
override val asmErrStrings get() = listOf(asmErrString, "type mismatch")
}
class OffsetNotConstant : CompileErr("Expected offset to be constant") {

View File

@ -4,6 +4,24 @@ import asmble.ast.Node
import org.objectweb.asm.Opcodes
import org.objectweb.asm.tree.*
/**
* Jvm representation of a function.
*
* @param name Name of the fn.
* @param params List of parameters of the fn.
* @param ret Type of the fn returner value.
* @param access The value of the access_flags item is a mask of flags used to
* denote access permissions to and properties of this class or
* interface [https://docs.oracle.com/javase/specs/jvms/se8/html/jvms-4.html#jvms-4.1-200-E.1].
* @param insns List of nodes that represents a bytecode instruction.
* @param stack A stack of operand types. Mirror of the operand stack(jvm stack)
* where types of operands instead operands.
* @param blockStack List of blocks of code
* @param ifStack Contains index of [org.objectweb.asm.tree.JumpInsnNode] that
* has a null label initially
* @param lastStackIsMemLeftover If there is the memory on the stack and we need it
* in the future, we mark it as leftover and reuse
*/
data class Func(
val name: String,
val params: List<TypeRef> = emptyList(),
@ -12,8 +30,8 @@ data class Func(
val insns: List<AbstractInsnNode> = emptyList(),
val stack: List<TypeRef> = emptyList(),
val blockStack: List<Block> = emptyList(),
// Contains index of JumpInsnNode that has a null label initially
val ifStack: List<Int> = emptyList()
val ifStack: List<Int> = emptyList(),
val lastStackIsMemLeftover: Boolean = false
) {
val desc: String get() = ret.asMethodRetDesc(*params.toTypedArray())
@ -109,10 +127,11 @@ data class Func(
}
}
fun pushBlock(insn: Node.Instr, labelType: Node.Type.Value?, endType: Node.Type.Value?) =
/** Creates new block with specified instruction and pushes it into the blockStack.*/
fun pushBlock(insn: Node.Instr, labelType: Node.Type.Value?, endType: Node.Type.Value?): Func =
pushBlock(insn, listOfNotNull(labelType?.typeRef), listOfNotNull(endType?.typeRef))
fun pushBlock(insn: Node.Instr, labelTypes: List<TypeRef>, endTypes: List<TypeRef>) =
fun pushBlock(insn: Node.Instr, labelTypes: List<TypeRef>, endTypes: List<TypeRef>): Func =
copy(blockStack = blockStack + Block(insn, insns.size, stack, labelTypes, endTypes))
fun popBlock() = copy(blockStack = blockStack.dropLast(1)) to blockStack.last()
@ -126,6 +145,22 @@ data class Func(
fun popIf() = copy(ifStack = ifStack.dropLast(1)) to peekIf()
/**
* Representation of code block.
*
* Blocks are composed of matched pairs of `block ... end` instructions, loops
* with matched pairs of `loop ... end` instructions, and ifs with either
* `if ... end` or if ... else ... end sequences. For each of these constructs
* the instructions in the ellipsis are said to be enclosed in the construct.
*
* @param isns Start instruction of this block, might be a 'Block', 'Loop'
* or 'If'
* @param startIndex Index of start instruction of this block in list of all
* instructions
* @param origStack Current block stack of operand types.
* @param labelTypes A type of label for this block
* @param endTypes A type of block return value
*/
class Block(
val insn: Node.Instr,
val startIndex: Int,

View File

@ -14,25 +14,31 @@ import java.lang.invoke.MethodHandle
// TODO: modularize
open class FuncBuilder {
fun fromFunc(ctx: ClsContext, f: Node.Func, index: Int): Func {
// TODO: validate local size?
// TODO: initialize non-param locals?
/**
* Converts wasm AST [asmble.ast.Node.Func] to Jvm bytecode representation [asmble.compile.jvm.Func].
*
* @param ctx A Global context for converting.
* @param fn AST of wasm fn.
* @param index Fn index, used for generating fn name
*/
fun fromFunc(ctx: ClsContext, fn: Node.Func, index: Int): Func {
ctx.debug { "Building function ${ctx.funcName(index)}" }
ctx.trace { "Function ast:\n${SExprToStr.fromSExpr(AstToSExpr.fromFunc(f))}" }
ctx.trace { "Function ast:\n${SExprToStr.fromSExpr(AstToSExpr.fromFunc(fn))}" }
var func = Func(
access = Opcodes.ACC_PRIVATE,
name = ctx.funcName(index),
params = f.type.params.map(Node.Type.Value::typeRef),
ret = f.type.ret?.let(Node.Type.Value::typeRef) ?: Void::class.ref
params = fn.type.params.map(Node.Type.Value::typeRef),
ret = fn.type.ret?.let(Node.Type.Value::typeRef) ?: Void::class.ref
)
// Rework the instructions
val reworkedInsns = ctx.reworker.rework(ctx, f)
val reworkedInsns = ctx.reworker.rework(ctx, fn)
// Start the implicit block
func = func.pushBlock(Node.Instr.Block(f.type.ret), f.type.ret, f.type.ret)
func = func.pushBlock(Node.Instr.Block(fn.type.ret), fn.type.ret, fn.type.ret)
// Create the context
val funcCtx = FuncContext(
cls = ctx,
node = f,
node = fn,
insns = reworkedInsns,
memIsLocalVar =
ctx.reworker.nonAdjacentMemAccesses(reworkedInsns) >= ctx.nonAdjacentMemAccessesRequiringLocalVar
@ -48,9 +54,9 @@ open class FuncBuilder {
// Add all instructions
ctx.debug { "Applying insns for function ${ctx.funcName(index)}" }
// All functions have an implicit block
func = funcCtx.insns.foldIndexed(func) { index, func, insn ->
func = funcCtx.insns.foldIndexed(func) { idx, f, insn ->
ctx.debug { "Applying insn $insn" }
val ret = applyInsn(funcCtx, func, insn, index)
val ret = applyInsn(funcCtx, f, insn, idx)
ctx.trace { "Resulting stack: ${ret.stack}"}
ret
}
@ -58,11 +64,11 @@ open class FuncBuilder {
// End the implicit block
val implicitBlock = func.currentBlock
func = applyEnd(funcCtx, func)
f.type.ret?.typeRef?.also { func = func.popExpecting(it, implicitBlock) }
fn.type.ret?.typeRef?.also { func = func.popExpecting(it, implicitBlock) }
// If the last instruction does not terminate, add the expected return
if (func.insns.isEmpty() || !func.insns.last().isTerminating) {
func = func.addInsns(InsnNode(when (f.type.ret) {
func = func.addInsns(InsnNode(when (fn.type.ret) {
null -> Opcodes.RETURN
Node.Type.Value.I32 -> Opcodes.IRETURN
Node.Type.Value.I64 -> Opcodes.LRETURN
@ -74,8 +80,10 @@ open class FuncBuilder {
}
fun applyInsn(ctx: FuncContext, fn: Func, i: Insn, index: Int) = when (i) {
is Insn.Node ->
applyNodeInsn(ctx, fn, i.insn, index)
is Insn.ImportFuncRefNeededOnStack ->
// Func refs are method handle fields
fn.addInsns(
@ -83,6 +91,7 @@ open class FuncBuilder {
FieldInsnNode(Opcodes.GETFIELD, ctx.cls.thisRef.asmName,
ctx.cls.funcName(i.index), MethodHandle::class.ref.asmDesc)
).push(MethodHandle::class.ref)
is Insn.ImportGlobalSetRefNeededOnStack ->
// Import setters are method handle fields
fn.addInsns(
@ -90,13 +99,17 @@ open class FuncBuilder {
FieldInsnNode(Opcodes.GETFIELD, ctx.cls.thisRef.asmName,
ctx.cls.importGlobalSetterFieldName(i.index), MethodHandle::class.ref.asmDesc)
).push(MethodHandle::class.ref)
is Insn.ThisNeededOnStack ->
// load a reference onto the stack from a local variable
fn.addInsns(VarInsnNode(Opcodes.ALOAD, 0)).push(ctx.cls.thisRef)
is Insn.MemNeededOnStack ->
putMemoryOnStackIfNecessary(ctx, fn)
putMemoryOnStack(ctx, fn)
}
fun applyNodeInsn(ctx: FuncContext, fn: Func, i: Node.Instr, index: Int) = when (i) {
is Node.Instr.Unreachable ->
fn.addInsns(UnsupportedOperationException::class.athrow("Unreachable")).markUnreachable()
is Node.Instr.Nop ->
@ -129,32 +142,29 @@ open class FuncBuilder {
fn.pop().let { (fn, popped) ->
fn.addInsns(InsnNode(if (popped.stackSize == 2) Opcodes.POP2 else Opcodes.POP))
}
is Node.Instr.Select ->
applySelectInsn(ctx, fn)
is Node.Instr.GetLocal ->
applyGetLocal(ctx, fn, i.index)
is Node.Instr.SetLocal ->
applySetLocal(ctx, fn, i.index)
is Node.Instr.TeeLocal ->
applyTeeLocal(ctx, fn, i.index)
is Node.Instr.GetGlobal ->
applyGetGlobal(ctx, fn, i.index)
is Node.Instr.SetGlobal ->
applySetGlobal(ctx, fn, i.index)
is Node.Instr.Select -> applySelectInsn(ctx, fn)
// Variable access
is Node.Instr.GetLocal -> applyGetLocal(ctx, fn, i.index)
is Node.Instr.SetLocal -> applySetLocal(ctx, fn, i.index)
is Node.Instr.TeeLocal -> applyTeeLocal(ctx, fn, i.index)
is Node.Instr.GetGlobal -> applyGetGlobal(ctx, fn, i.index)
is Node.Instr.SetGlobal -> applySetGlobal(ctx, fn, i.index)
// Memory operators
is Node.Instr.I32Load, is Node.Instr.I64Load, is Node.Instr.F32Load, is Node.Instr.F64Load,
is Node.Instr.I32Load8S, is Node.Instr.I32Load8U, is Node.Instr.I32Load16U, is Node.Instr.I32Load16S,
is Node.Instr.I64Load8S, is Node.Instr.I64Load8U, is Node.Instr.I64Load16U, is Node.Instr.I64Load16S,
is Node.Instr.I64Load32S, is Node.Instr.I64Load32U ->
// TODO: why do I have to cast?
applyLoadOp(ctx, fn, i as Node.Instr.Args.AlignOffset)
is Node.Instr.I32Store, is Node.Instr.I64Store, is Node.Instr.F32Store, is Node.Instr.F64Store,
is Node.Instr.I32Store8, is Node.Instr.I32Store16, is Node.Instr.I64Store8, is Node.Instr.I64Store16,
is Node.Instr.I64Store32 ->
applyStoreOp(ctx, fn, i as Node.Instr.Args.AlignOffset, index)
is Node.Instr.CurrentMemory ->
applyCurrentMemory(ctx, fn)
is Node.Instr.GrowMemory ->
applyGrowMemory(ctx, fn)
is Node.Instr.MemorySize ->
applyMemorySize(ctx, fn)
is Node.Instr.MemoryGrow ->
applyMemoryGrow(ctx, fn)
is Node.Instr.I32Const ->
fn.addInsns(i.value.const).push(Int::class.ref)
is Node.Instr.I64Const ->
@ -464,18 +474,18 @@ open class FuncBuilder {
fun popUntilStackSize(
ctx: FuncContext,
fn: Func,
func: Func,
block: Func.Block,
untilStackSize: Int,
keepLast: Boolean
): Func {
ctx.debug { "For block ${block.insn}, popping until stack size $untilStackSize, keeping last? $keepLast" }
// Just get the latest, don't actually pop...
val type = if (keepLast) fn.pop().second else null
return (0 until Math.max(0, fn.stack.size - untilStackSize)).fold(fn) { fn, _ ->
val type = if (keepLast) func.pop().second else null
return (0 until Math.max(0, func.stack.size - untilStackSize)).fold(func) { fn, _ ->
// Essentially swap and pop if they want to keep the latest
(if (type != null && fn.stack.size > 1) fn.stackSwap(block) else fn).let { fn ->
fn.pop(block).let { (fn, poppedType) ->
(if (type != null && fn.stack.size > 1) fn.stackSwap(block) else fn).let { f ->
f.pop(block).let { (fn, poppedType) ->
fn.addInsns(InsnNode(if (poppedType.stackSize == 2) Opcodes.POP2 else Opcodes.POP))
}
}
@ -497,8 +507,15 @@ open class FuncBuilder {
// Must at least have the item on the stack that the block expects if it expects something
val needsPopBeforeJump = needsToPopBeforeJumping(ctx, fn, block)
val toLabel = if (needsPopBeforeJump) LabelNode() else block.requiredLabel
fn.addInsns(JumpInsnNode(Opcodes.IFNE, toLabel)).let { fn ->
block.endTypes.firstOrNull()?.let { fn.peekExpecting(it) }
fn.addInsns(JumpInsnNode(Opcodes.IFNE, toLabel)).let { origFn ->
val fn = block.endTypes.firstOrNull()?.let { endType ->
// We have to pop the stack and re-push to get the right type after unreachable here...
// Ref: https://github.com/WebAssembly/spec/pull/537
// Update: but only if it's not a loop
// Ref: https://github.com/WebAssembly/spec/pull/610
if (block.insn is Node.Instr.Loop) origFn
else origFn.popExpecting(endType).push(endType)
} ?: origFn
if (needsPopBeforeJump) buildPopBeforeJump(ctx, fn, block, toLabel)
else fn
}
@ -1058,27 +1075,37 @@ open class FuncBuilder {
).push(Int::class.ref)
}
fun applyGrowMemory(ctx: FuncContext, fn: Func) =
fun applyMemoryGrow(ctx: FuncContext, fn: Func) =
// Grow mem is a special case where the memory ref is already pre-injected on
// the stack before this call. Result is an int.
ctx.cls.assertHasMemory().let {
ctx.cls.mem.growMemory(ctx, fn)
}
fun applyCurrentMemory(ctx: FuncContext, fn: Func) =
fun applyMemorySize(ctx: FuncContext, fn: Func) =
// Curr mem is not specially injected, so we have to put the memory on the
// stack since we need it
ctx.cls.assertHasMemory().let {
putMemoryOnStackIfNecessary(ctx, fn).let { fn -> ctx.cls.mem.currentMemory(ctx, fn) }
putMemoryOnStack(ctx, fn).let { fn -> ctx.cls.mem.currentMemory(ctx, fn) }
}
/**
* Store is a special case where the memory ref is already pre-injected on
* the stack before this call. But it can have a memory leftover on the stack
* so we pop it if we need to
*/
fun applyStoreOp(ctx: FuncContext, fn: Func, insn: Node.Instr.Args.AlignOffset, insnIndex: Int) =
// Store is a special case where the memory ref is already pre-injected on
// the stack before this call. But it can have a memory leftover on the stack
// so we pop it if we need to
ctx.cls.assertHasMemory().let {
ctx.cls.mem.storeOp(ctx, fn, insn).let { fn ->
popMemoryIfNecessary(ctx, fn, ctx.insns.getOrNull(insnIndex + 1))
// As a special case, if this leaves the mem on the stack
// and we need it in the future, we mark it as leftover and
// reuse
if (!ctx.cls.mem.storeLeavesMemOnStack) fn else ctx.insns.getOrNull(insnIndex + 1).let { nextInsn ->
if (nextInsn is Insn.MemNeededOnStack) {
fn.peekExpecting(ctx.cls.mem.memType)
fn.copy(lastStackIsMemLeftover = true)
} else fn.popExpecting(ctx.cls.mem.memType).addInsns(InsnNode(Opcodes.POP))
}
}
}
@ -1089,8 +1116,9 @@ open class FuncBuilder {
ctx.cls.mem.loadOp(ctx, fn, insn)
}
fun putMemoryOnStackIfNecessary(ctx: FuncContext, fn: Func) =
if (fn.stack.lastOrNull() == ctx.cls.mem.memType) fn
fun putMemoryOnStack(ctx: FuncContext, fn: Func) =
// Only put it if it's not already leftover
if (fn.lastStackIsMemLeftover) fn.copy(lastStackIsMemLeftover = false)
else if (ctx.memIsLocalVar)
// Assume it's just past the locals
fn.addInsns(VarInsnNode(Opcodes.ALOAD, ctx.actualLocalIndex(ctx.node.localsSize))).
@ -1100,19 +1128,6 @@ open class FuncBuilder {
FieldInsnNode(Opcodes.GETFIELD, ctx.cls.thisRef.asmName, "memory", ctx.cls.mem.memType.asmDesc)
).push(ctx.cls.mem.memType)
fun popMemoryIfNecessary(ctx: FuncContext, fn: Func, nextInsn: Insn?) =
// We pop the mem if it's there and not a mem op next
if (fn.stack.lastOrNull() != ctx.cls.mem.memType) fn else {
val nextInstrRequiresMemOnStack = when (nextInsn) {
is Insn.Node -> nextInsn.insn is Node.Instr.Args.AlignOffset ||
nextInsn.insn is Node.Instr.CurrentMemory || nextInsn.insn is Node.Instr.GrowMemory
is Insn.MemNeededOnStack -> true
else -> false
}
if (nextInstrRequiresMemOnStack) fn
else fn.popExpecting(ctx.cls.mem.memType).addInsns(InsnNode(Opcodes.POP))
}
fun applySetGlobal(ctx: FuncContext, fn: Func, index: Int) = ctx.cls.globalAtIndex(index).let {
when (it) {
is Either.Left -> applyImportSetGlobal(ctx, fn, index, it.v.kind as Node.Import.Kind.Global)
@ -1268,11 +1283,11 @@ open class FuncBuilder {
}
}
fun applyReturnInsn(ctx: FuncContext, fn: Func): Func {
// If the current stakc is unreachable, we consider that our block since it
fun applyReturnInsn(ctx: FuncContext, func: Func): Func {
// If the current stack is unreachable, we consider that our block since it
// will pop properly.
val block = if (fn.currentBlock.unreachable) fn.currentBlock else fn.blockStack.first()
popForBlockEscape(ctx, fn, block).let { fn ->
val block = if (func.currentBlock.unreachable) func.currentBlock else func.blockStack.first()
popForBlockEscape(ctx, func, block).let { fn ->
return when (ctx.node.type.ret) {
null ->
fn.addInsns(InsnNode(Opcodes.RETURN))
@ -1284,9 +1299,9 @@ open class FuncBuilder {
fn.popExpecting(Float::class.ref, block).addInsns(InsnNode(Opcodes.FRETURN))
Node.Type.Value.F64 ->
fn.popExpecting(Double::class.ref, block).addInsns(InsnNode(Opcodes.DRETURN))
}.let { fn ->
if (fn.stack.isNotEmpty()) throw CompileErr.UnusedStackOnReturn(fn.stack)
fn.markUnreachable()
}.let { it ->
if (it.stack.isNotEmpty()) throw CompileErr.UnusedStackOnReturn(it.stack)
it.markUnreachable()
}
}
}

View File

@ -3,6 +3,15 @@ package asmble.compile.jvm
import asmble.ast.Node
import asmble.util.Logger
/**
* Jvm context of execution a function.
*
* @param cls Class execution context
* @param node Ast of this function
* @param insns A list of instructions
* @param memIsLocalVar If true then function use only local variables and don't load
* and store from memory.
*/
data class FuncContext(
val cls: ClsContext,
val node: Node.Func,

View File

@ -2,6 +2,9 @@ package asmble.compile.jvm
import asmble.ast.Node
/**
* Does some special manipulations with instruction.
*/
open class InsnReworker {
fun rework(ctx: ClsContext, func: Node.Func): List<Insn> {
@ -29,6 +32,7 @@ open class InsnReworker {
// Note, while walking backwards up the insns to find set/tee, we do skip entire
// blocks/loops/if+else combined with "end"
var neededEagerLocalIndices = emptySet<Int>()
fun addEagerSetIfNeeded(getInsnIndex: Int, localIndex: Int) {
// Within the param range? nothing needed
if (localIndex < func.type.params.size) return
@ -71,6 +75,7 @@ open class InsnReworker {
(insn is Insn.Node && insn.insn !is Node.Instr.SetLocal && insn.insn !is Node.Instr.TeeLocal)
if (needsEagerInit) neededEagerLocalIndices += localIndex
}
insns.forEachIndexed { index, insn ->
if (insn is Insn.Node && insn.insn is Node.Instr.GetLocal) addEagerSetIfNeeded(index, insn.insn.index)
}
@ -86,16 +91,25 @@ open class InsnReworker {
} + insns
}
/**
* Puts into instruction list needed instructions for pushing local variables
* into the stack and returns list of resulting instructions.
*
* @param ctx The Execution context
* @param insns The original instructions
*/
fun injectNeededStackVars(ctx: ClsContext, insns: List<Node.Instr>): List<Insn> {
ctx.trace { "Calculating places to inject needed stack variables" }
// How we do this:
// We run over each insn, and keep a running list of stack
// manips. If there is an insn that needs something so far back,
// manips. If there is an ins'n that needs something so far back,
// we calc where it needs to be added and keep a running list of
// insn inserts. Then at the end we settle up.
//
// Note, we don't do any injections for things like "this" if
// they aren't needed up the stack (e.g. a simple getfield can
// just aload 0 itself)
// just aload 0 itself). Also we take special care not to inject
// inside of an inner block.
// Each pair is first the amount of stack that is changed (0 is
// ignored, push is positive, pull is negative) then the index
@ -107,6 +121,14 @@ open class InsnReworker {
// guarantee the value will be in the right order if there are
// multiple for the same index
var insnsToInject = emptyMap<Int, List<Insn>>()
/**
* This function inject current instruction in stack.
*
* @param insn The instruction to inject
* @param count Number of step back on the stack that should we do for
* finding injection index.
*/
fun injectBeforeLastStackCount(insn: Insn, count: Int) {
ctx.trace { "Injecting $insn back $count stack values" }
fun inject(index: Int) {
@ -115,18 +137,53 @@ open class InsnReworker {
if (count == 0) return inject(stackManips.size)
var countSoFar = 0
var foundUnconditionalJump = false
var insideOfBlocks = 0
for ((amountChanged, insnIndex) in stackManips.asReversed()) {
// We have to skip inner blocks because we don't want to inject inside of there
if (insns[insnIndex] == Node.Instr.End) {
insideOfBlocks++
ctx.trace { "Found end, not injecting until before $insideOfBlocks more block start(s)" }
continue
}
// When we reach the top of a block, we need to decrement out inside count and
// if we are at 0, add the result of said block if necessary to the count.
if (insideOfBlocks > 0) {
// If it's not a block, just ignore it
val blockStackDiff = insns[insnIndex].let {
when (it) {
is Node.Instr.Block -> if (it.type == null) 0 else 1
is Node.Instr.Loop -> 0
is Node.Instr.If -> if (it.type == null) -1 else 0
else -> null
}
}
if (blockStackDiff != null) {
insideOfBlocks--
ctx.trace { "Found block begin, number of blocks we're still inside: $insideOfBlocks" }
// We're back on our block, change the count
if (insideOfBlocks == 0) countSoFar += blockStackDiff
}
if (insideOfBlocks > 0) continue
}
countSoFar += amountChanged
if (!foundUnconditionalJump) foundUnconditionalJump = insns[insnIndex].let { insn ->
if (!foundUnconditionalJump) {
foundUnconditionalJump = insns[insnIndex].let { insn ->
insn is Node.Instr.Br || insn is Node.Instr.BrTable ||
insn is Node.Instr.Unreachable || insn is Node.Instr.Return
}
if (countSoFar == count) return inject(insnIndex)
}
if (countSoFar == count) {
ctx.trace { "Found injection point as before insn #$insnIndex" }
return inject(insnIndex)
}
}
// Only consider it a failure if we didn't hit any unconditional jumps
if (!foundUnconditionalJump) throw CompileErr.StackInjectionMismatch(count, insn)
}
var traceStackSize = 0 // Used only for trace
// Go over each insn, determining where to inject
insns.forEachIndexed { index, insn ->
// Handle special injection cases
@ -160,24 +217,37 @@ open class InsnReworker {
is Node.Instr.I64Store32 ->
injectBeforeLastStackCount(Insn.MemNeededOnStack, 2)
// Grow memory requires "mem" before the single param
is Node.Instr.GrowMemory ->
is Node.Instr.MemoryGrow ->
injectBeforeLastStackCount(Insn.MemNeededOnStack, 1)
else -> { }
}
// Log some trace output
ctx.trace {
insnStackDiff(ctx, insn).let {
traceStackSize += it
"Stack diff is $it for insn #$index $insn, stack size now: $traceStackSize"
}
}
// Add the current diff
ctx.trace { "Stack diff is ${insnStackDiff(ctx, insn)} for $insn" }
stackManips += insnStackDiff(ctx, insn) to index
}
// Build resulting list
return insns.foldIndexed(emptyList<Insn>()) { index, ret, insn ->
return insns.foldIndexed(emptyList()) { index, ret, insn ->
val injections = insnsToInject[index] ?: emptyList()
ret + injections + Insn.Node(insn)
}
}
fun insnStackDiff(ctx: ClsContext, insn: Node.Instr) = when (insn) {
/**
* Calculate stack difference after calling instruction current instruction.
* Returns the difference from stack cursor position before instruction and after.
* `N = PUSH_OPS - POP_OPS.` '-n' mean that POP operation will be more than PUSH.
* If '0' then stack won't changed.
*/
fun insnStackDiff(ctx: ClsContext, insn: Node.Instr): Int = when (insn) {
is Node.Instr.Unreachable, is Node.Instr.Nop, is Node.Instr.Block,
is Node.Instr.Loop, is Node.Instr.If, is Node.Instr.Else,
is Node.Instr.End, is Node.Instr.Br, is Node.Instr.BrIf,
@ -196,17 +266,17 @@ open class InsnReworker {
is Node.Instr.GetLocal -> PUSH_RESULT
is Node.Instr.SetLocal -> POP_PARAM
is Node.Instr.TeeLocal -> POP_PARAM + PUSH_RESULT
is Node.Instr.GetGlobal -> POP_THIS + PUSH_RESULT
is Node.Instr.SetGlobal -> POP_THIS + POP_PARAM
is Node.Instr.GetGlobal -> PUSH_RESULT
is Node.Instr.SetGlobal -> POP_PARAM
is Node.Instr.I32Load, is Node.Instr.I64Load, is Node.Instr.F32Load, is Node.Instr.F64Load,
is Node.Instr.I32Load8S, is Node.Instr.I32Load8U, is Node.Instr.I32Load16U, is Node.Instr.I32Load16S,
is Node.Instr.I64Load8S, is Node.Instr.I64Load8U, is Node.Instr.I64Load16U, is Node.Instr.I64Load16S,
is Node.Instr.I64Load32S, is Node.Instr.I64Load32U -> POP_PARAM + PUSH_RESULT
is Node.Instr.I32Store, is Node.Instr.I64Store, is Node.Instr.F32Store, is Node.Instr.F64Store,
is Node.Instr.I32Store8, is Node.Instr.I32Store16, is Node.Instr.I64Store8, is Node.Instr.I64Store16,
is Node.Instr.I64Store32 -> POP_PARAM
is Node.Instr.CurrentMemory -> PUSH_RESULT
is Node.Instr.GrowMemory -> POP_PARAM + PUSH_RESULT
is Node.Instr.I64Store32 -> POP_PARAM + POP_PARAM
is Node.Instr.MemorySize -> PUSH_RESULT
is Node.Instr.MemoryGrow -> POP_PARAM + PUSH_RESULT
is Node.Instr.I32Const, is Node.Instr.I64Const,
is Node.Instr.F32Const, is Node.Instr.F64Const -> PUSH_RESULT
is Node.Instr.I32Add, is Node.Instr.I32Sub, is Node.Instr.I32Mul, is Node.Instr.I32DivS,
@ -229,16 +299,16 @@ open class InsnReworker {
is Node.Instr.I64Eqz -> POP_PARAM + PUSH_RESULT
is Node.Instr.F32Add, is Node.Instr.F32Sub, is Node.Instr.F32Mul, is Node.Instr.F32Div,
is Node.Instr.F32Eq, is Node.Instr.F32Ne, is Node.Instr.F32Lt, is Node.Instr.F32Le,
is Node.Instr.F32Gt, is Node.Instr.F32Ge, is Node.Instr.F32Sqrt, is Node.Instr.F32Min,
is Node.Instr.F32Max -> POP_PARAM + POP_PARAM + PUSH_RESULT
is Node.Instr.F32Abs, is Node.Instr.F32Neg, is Node.Instr.F32CopySign, is Node.Instr.F32Ceil,
is Node.Instr.F32Floor, is Node.Instr.F32Trunc, is Node.Instr.F32Nearest -> POP_PARAM + PUSH_RESULT
is Node.Instr.F32Gt, is Node.Instr.F32Ge, is Node.Instr.F32Min,
is Node.Instr.F32Max, is Node.Instr.F32CopySign -> POP_PARAM + POP_PARAM + PUSH_RESULT
is Node.Instr.F32Abs, is Node.Instr.F32Neg, is Node.Instr.F32Ceil, is Node.Instr.F32Floor,
is Node.Instr.F32Trunc, is Node.Instr.F32Nearest, is Node.Instr.F32Sqrt -> POP_PARAM + PUSH_RESULT
is Node.Instr.F64Add, is Node.Instr.F64Sub, is Node.Instr.F64Mul, is Node.Instr.F64Div,
is Node.Instr.F64Eq, is Node.Instr.F64Ne, is Node.Instr.F64Lt, is Node.Instr.F64Le,
is Node.Instr.F64Gt, is Node.Instr.F64Ge, is Node.Instr.F64Sqrt, is Node.Instr.F64Min,
is Node.Instr.F64Max -> POP_PARAM + POP_PARAM + PUSH_RESULT
is Node.Instr.F64Abs, is Node.Instr.F64Neg, is Node.Instr.F64CopySign, is Node.Instr.F64Ceil,
is Node.Instr.F64Floor, is Node.Instr.F64Trunc, is Node.Instr.F64Nearest -> POP_PARAM + PUSH_RESULT
is Node.Instr.F64Gt, is Node.Instr.F64Ge, is Node.Instr.F64Min,
is Node.Instr.F64Max, is Node.Instr.F64CopySign -> POP_PARAM + POP_PARAM + PUSH_RESULT
is Node.Instr.F64Abs, is Node.Instr.F64Neg, is Node.Instr.F64Ceil, is Node.Instr.F64Floor,
is Node.Instr.F64Trunc, is Node.Instr.F64Nearest, is Node.Instr.F64Sqrt -> POP_PARAM + PUSH_RESULT
is Node.Instr.I32WrapI64, is Node.Instr.I32TruncSF32, is Node.Instr.I32TruncUF32,
is Node.Instr.I32TruncSF64, is Node.Instr.I32TruncUF64, is Node.Instr.I64ExtendSI32,
is Node.Instr.I64ExtendUI32, is Node.Instr.I64TruncSF32, is Node.Instr.I64TruncUF32,
@ -250,23 +320,23 @@ open class InsnReworker {
is Node.Instr.F64ReinterpretI64 -> POP_PARAM + PUSH_RESULT
}
fun nonAdjacentMemAccesses(insns: List<Insn>) = insns.fold(0 to false) { (count, lastCouldHaveMem), insn ->
/** Returns number of memory accesses. */
fun nonAdjacentMemAccesses(insns: List<Insn>): Int = insns.fold(0 to false) { (count, lastCouldHaveMem), insn ->
val inc =
if (lastCouldHaveMem) 0
else if (insn == Insn.MemNeededOnStack) 1
else if (insn is Insn.Node && insn.insn is Node.Instr.CurrentMemory) 1
else if (insn is Insn.Node && insn.insn is Node.Instr.MemorySize) 1
else 0
val couldSetMemNext = if (insn !is Insn.Node) false else when (insn.insn) {
is Node.Instr.I32Store, is Node.Instr.I64Store, is Node.Instr.F32Store, is Node.Instr.F64Store,
is Node.Instr.I32Store8, is Node.Instr.I32Store16, is Node.Instr.I64Store8, is Node.Instr.I64Store16,
is Node.Instr.I64Store32, is Node.Instr.GrowMemory -> true
is Node.Instr.I64Store32, is Node.Instr.MemoryGrow -> true
else -> false
}
(count + inc) to couldSetMemNext
}.let { (count, _) -> count }
companion object : InsnReworker() {
const val POP_THIS = -1
const val POP_PARAM = -1
const val PUSH_RESULT = 1
const val NOP = 0

View File

@ -0,0 +1,244 @@
package asmble.compile.jvm
import asmble.annotation.WasmExport
import asmble.annotation.WasmExternalKind
import asmble.annotation.WasmImport
import asmble.annotation.WasmModule
import org.objectweb.asm.Handle
import org.objectweb.asm.Opcodes
import org.objectweb.asm.Type
import org.objectweb.asm.tree.*
import java.lang.invoke.MethodHandle
open class Linker {
fun link(ctx: Context) {
// Quick check to prevent duplicate names
ctx.classes.groupBy { it.name }.values.forEach {
require(it.size == 1) { "Duplicate module name: ${it.first().name}"}
}
// Common items
ctx.cls.superName = Object::class.ref.asmName
ctx.cls.version = Opcodes.V1_8
ctx.cls.access += Opcodes.ACC_PUBLIC
addConstructor(ctx)
addDefaultMaxMemField(ctx)
// Go over each module and add its creation and instance methods
ctx.classes.forEach {
addCreationMethod(ctx, it)
addInstanceField(ctx, it)
addInstanceMethod(ctx, it)
}
TODO()
}
fun addConstructor(ctx: Context) {
// Just the default empty constructor
ctx.cls.methods.plusAssign(
Func(
access = Opcodes.ACC_PUBLIC,
name = "<init>",
params = emptyList(),
ret = Void::class.ref,
insns = listOf(
VarInsnNode(Opcodes.ALOAD, 0),
MethodInsnNode(Opcodes.INVOKESPECIAL, Object::class.ref.asmName, "<init>", "()V", false),
InsnNode(Opcodes.RETURN)
)
).toMethodNode()
)
}
fun addDefaultMaxMemField(ctx: Context) {
(Int.MAX_VALUE / Mem.PAGE_SIZE).let { maxAllowed ->
require(ctx.defaultMaxMemPages <= maxAllowed) {
"Page size ${ctx.defaultMaxMemPages} over max allowed $maxAllowed"
}
}
ctx.cls.fields.plusAssign(FieldNode(
// Make it volatile since it will be publicly mutable
Opcodes.ACC_PUBLIC + Opcodes.ACC_VOLATILE,
"defaultMaxMem",
"I",
null,
ctx.defaultMaxMemPages * Mem.PAGE_SIZE
))
}
fun addCreationMethod(ctx: Context, mod: ModuleClass) {
// The creation method accepts everything needed for import in order of
// imports. For creating a mod w/ self-built memory, we use a default max
// mem field on the linkage class if there isn't a default already.
val params = mod.importClasses(ctx)
var func = Func(
access = Opcodes.ACC_PROTECTED,
name = "create" + mod.name.javaIdent.capitalize(),
params = params.map(ModuleClass::ref),
ret = mod.ref
)
// The stack here on out is for building params to constructor...
// The constructor we'll use is:
// * Mem-class based constructor if it's an import
// * Max-mem int based constructor if mem is self-built and doesn't have a no-mem-no-max ctr
// * Should be only single constructor with imports when there's no mem
val memClassCtr = mod.cls.constructors.find { it.parameters.firstOrNull()?.type?.ref == ctx.mem.memType }
val constructor = if (memClassCtr == null) mod.cls.constructors.singleOrNull() else {
// Use the import annotated one if there
if (memClassCtr.parameters.first().isAnnotationPresent(WasmImport::class.java)) memClassCtr else {
// If there is a non-int-starting constructor, we want to use that
val nonMaxMemCtr = mod.cls.constructors.find {
it != memClassCtr && it.parameters.firstOrNull()?.type != Integer.TYPE
}
if (nonMaxMemCtr != null) nonMaxMemCtr else {
// Use the max-mem constructor and put the int on the stack
func = func.addInsns(
VarInsnNode(Opcodes.ALOAD, 0),
FieldInsnNode(Opcodes.GETFIELD, ctx.cls.name, "defaultMaxMem", "I")
)
mod.cls.constructors.find { it.parameters.firstOrNull()?.type != Integer.TYPE }
}
}
}
if (constructor == null) error("Unable to find suitable constructor for ${mod.cls}")
// Now just go over the imports and put them on the stack
func = constructor.parameters.fold(func) { func, param ->
param.getAnnotation(WasmImport::class.java).let { import ->
when (import.kind) {
// Invoke the mem handle to get the mem
// TODO: for imported memory, fail if import.limit < limits.init * page size at runtime
// TODO: for imported memory, fail if import.cap > limits.max * page size at runtime
WasmExternalKind.MEMORY -> func.addInsns(
VarInsnNode(Opcodes.ALOAD, 1 + params.indexOfFirst { it.name == import.module }),
ctx.resolveImportHandle(import).let { memGet ->
MethodInsnNode(Opcodes.INVOKEVIRTUAL, memGet.owner, memGet.name, memGet.desc, false)
}
)
// Bind the method
WasmExternalKind.FUNCTION -> func.addInsns(
LdcInsnNode(ctx.resolveImportHandle(import)),
VarInsnNode(Opcodes.ALOAD, 1 + params.indexOfFirst { it.name == import.module }),
MethodHandle::bindTo.invokeVirtual()
)
// Bind the getter
WasmExternalKind.GLOBAL -> func.addInsns(
LdcInsnNode(ctx.resolveImportHandle(import)),
VarInsnNode(Opcodes.ALOAD, 1 + params.indexOfFirst { it.name == import.module }),
MethodHandle::bindTo.invokeVirtual()
)
// Invoke to get handle array
// TODO: for imported table, fail if import.size < limits.init * page size at runtime
// TODO: for imported table, fail if import.size > limits.max * page size at runtime
WasmExternalKind.TABLE -> func.addInsns(
VarInsnNode(Opcodes.ALOAD, 1 + params.indexOfFirst { it.name == import.module }),
ctx.resolveImportHandle(import).let { tblGet ->
MethodInsnNode(Opcodes.INVOKEVIRTUAL, tblGet.owner, tblGet.name, tblGet.desc, false)
}
)
}
}
}
// Now with all items on the stack we can instantiate and return
func = func.addInsns(
TypeInsnNode(Opcodes.NEW, mod.ref.asmName),
InsnNode(Opcodes.DUP),
MethodInsnNode(
Opcodes.INVOKESPECIAL,
mod.ref.asmName,
"<init>",
constructor.ref.asmDesc,
false
),
InsnNode(Opcodes.ARETURN)
)
ctx.cls.methods.plusAssign(func.toMethodNode())
}
fun addInstanceField(ctx: Context, mod: ModuleClass) {
// Simple protected field that is lazily populated (but doesn't need to be volatile)
ctx.cls.fields.plusAssign(
FieldNode(Opcodes.ACC_PROTECTED, "instance" + mod.name.javaIdent.capitalize(),
mod.ref.asmDesc, null, null)
)
}
fun addInstanceMethod(ctx: Context, mod: ModuleClass) {
// The instance method accepts no parameters. It lazily populates a field by calling the
// creation method. The parameters for the creation method are the imports that are
// accessed via their instance methods. The entire method is synchronized as that is the
// most straightforward way to thread-safely lock the lazy population for now.
val params = mod.importClasses(ctx)
var func = Func(
access = Opcodes.ACC_PUBLIC + Opcodes.ACC_SYNCHRONIZED,
name = mod.name.javaIdent,
ret = mod.ref
)
val alreadyThereLabel = LabelNode()
func = func.addInsns(
VarInsnNode(Opcodes.ALOAD, 0),
FieldInsnNode(Opcodes.GETFIELD, ctx.cls.name,
"instance" + mod.name.javaIdent.capitalize(), mod.ref.asmDesc),
JumpInsnNode(Opcodes.IFNONNULL, alreadyThereLabel),
VarInsnNode(Opcodes.ALOAD, 0)
)
func = params.fold(func) { func, importMod ->
func.addInsns(
VarInsnNode(Opcodes.ALOAD, 0),
MethodInsnNode(Opcodes.INVOKEVIRTUAL, importMod.ref.asmName,
importMod.name.javaIdent, importMod.ref.asMethodRetDesc(), false)
)
}
func = func.addInsns(
FieldInsnNode(Opcodes.PUTFIELD, ctx.cls.name,
"instance" + mod.name.javaIdent.capitalize(), mod.ref.asmDesc),
alreadyThereLabel,
VarInsnNode(Opcodes.ALOAD, 0),
FieldInsnNode(Opcodes.GETFIELD, ctx.cls.name,
"instance" + mod.name.javaIdent.capitalize(), mod.ref.asmDesc),
InsnNode(Opcodes.ARETURN)
)
ctx.cls.methods.plusAssign(func)
}
class ModuleClass(val cls: Class<*>, overrideName: String? = null) {
val name = overrideName ?:
cls.getDeclaredAnnotation(WasmModule::class.java)?.name ?: error("No module name available for class $cls")
val ref = TypeRef(Type.getType(cls))
fun importClasses(ctx: Context): List<ModuleClass> {
// Try to find constructor with mem class first, otherwise there should be only one
val constructorWithImports = cls.constructors.find {
it.parameters.firstOrNull()?.type?.ref == ctx.mem.memType
} ?: cls.constructors.singleOrNull() ?: error("Unable to find suitable constructor for $cls")
return constructorWithImports.parameters.toList().mapNotNull {
it.getAnnotation(WasmImport::class.java)?.module
}.distinct().map(ctx::namedModuleClass)
}
}
data class Context(
val classes: List<ModuleClass>,
val className: String,
val cls: ClassNode = ClassNode().also { it.name = className.replace('.', '/') },
val mem: Mem = ByteBufferMem,
val defaultMaxMemPages: Int = 10
) {
fun namedModuleClass(name: String) = classes.find { it.name == name } ?: error("No module named '$name'")
fun resolveImportMethod(import: WasmImport) =
namedModuleClass(import.module).cls.methods.find { method ->
method.getAnnotation(WasmExport::class.java)?.value == import.field &&
method.ref.asmDesc == import.desc
} ?: error("Unable to find export named '${import.field}' in module '${import.module}'")
fun resolveImportHandle(import: WasmImport) = resolveImportMethod(import).let { method ->
Handle(Opcodes.INVOKEVIRTUAL, method.declaringClass.ref.asmName, method.name, method.ref.asmDesc, false)
}
}
companion object : Linker()
}

View File

@ -35,10 +35,13 @@ interface Mem {
fun loadOp(ctx: FuncContext, func: Func, insn: Node.Instr.Args.AlignOffset): Func
// Caller can trust the mem instance is on the stack followed
// by the value. If it's already there after call anyways, this can
// leave the mem inst on the stack and it will be reused or popped.
// by the value. If storeLeavesMemOnStack is true, this should leave the mem
// on the stack after every call.
fun storeOp(ctx: FuncContext, func: Func, insn: Node.Instr.Args.AlignOffset): Func
// Whether or not storeOp leaves a mem instance on the stack
val storeLeavesMemOnStack: Boolean
companion object {
const val PAGE_SIZE = 65536
}

View File

@ -2,14 +2,25 @@ package asmble.compile.jvm
import org.objectweb.asm.Type
/**
* A Java field or method type. This class can be used to make it easier to
* manipulate type and method descriptors.
*
* @param asm Wrapped [org.objectweb.asm.Type] from asm library
*/
data class TypeRef(val asm: Type) {
/** The internal name of the class corresponding to this object or array type. */
val asmName: String get() = asm.internalName
/** The descriptor corresponding to this Java type. */
val asmDesc: String get() = asm.descriptor
fun asMethodRetDesc(vararg args: TypeRef) = Type.getMethodDescriptor(asm, *args.map { it.asm }.toTypedArray())
/** Size of this type in stack, either 1 or 2 only allowed, where 1 = 2^32` bits */
val stackSize: Int get() = if (asm == Type.DOUBLE_TYPE || asm == Type.LONG_TYPE) 2 else 1
fun asMethodRetDesc(vararg args: TypeRef) = Type.getMethodDescriptor(asm, *args.map { it.asm }.toTypedArray())
fun equivalentTo(other: TypeRef) = this == other || this == Unknown || other == Unknown
object UnknownType

View File

@ -5,6 +5,7 @@ import asmble.util.toRawIntBits
import asmble.util.toRawLongBits
import asmble.util.toUnsignedBigInt
import asmble.util.toUnsignedLong
import java.io.ByteArrayOutputStream
open class AstToBinary(val version: Long = 1L) {
@ -140,6 +141,9 @@ open class AstToBinary(val version: Long = 1L) {
fromResizableLimits(b, n.limits)
}
fun fromModule(n: Node.Module) =
ByteArrayOutputStream().also { fromModule(ByteWriter.OutputStream(it), n) }.toByteArray()
fun fromModule(b: ByteWriter, n: Node.Module) {
b.writeUInt32(0x6d736100)
b.writeUInt32(version)
@ -160,10 +164,33 @@ open class AstToBinary(val version: Long = 1L) {
wrapListSection(b, n, 9, n.elems, this::fromElem)
wrapListSection(b, n, 10, n.funcs, this::fromFuncBody)
wrapListSection(b, n, 11, n.data, this::fromData)
n.names?.also { fromNames(b, it) }
// All other custom sections after the previous
n.customSections.filter { it.afterSectionId > 11 }.forEach { fromCustomSection(b, it) }
}
fun fromNames(b: ByteWriter, n: Node.NameSection) {
fun <T> indexMap(b: ByteWriter, map: Map<Int, T>, fn: (T) -> Unit) {
b.writeVarUInt32(map.size)
map.forEach { index, v -> b.writeVarUInt32(index).also { fn(v) } }
}
fun nameMap(b: ByteWriter, map: Map<Int, String>) = indexMap(b, map) { b.writeString(it) }
b.writeVarUInt7(0)
b.withVarUInt32PayloadSizePrepended { b ->
b.writeString("name")
n.moduleName?.also { moduleName ->
b.writeVarUInt7(0)
b.withVarUInt32PayloadSizePrepended { b -> b.writeString(moduleName) }
}
if (n.funcNames.isNotEmpty()) b.writeVarUInt7(1).also {
b.withVarUInt32PayloadSizePrepended { b -> nameMap(b, n.funcNames) }
}
if (n.localNames.isNotEmpty()) b.writeVarUInt7(2).also {
b.withVarUInt32PayloadSizePrepended { b -> indexMap(b, n.localNames) { nameMap(b, it) } }
}
}
}
fun fromResizableLimits(b: ByteWriter, n: Node.ResizableLimits) {
b.writeVarUInt1(n.maximum != null)
b.writeVarUInt32(n.initial)

View File

@ -4,7 +4,7 @@ import asmble.ast.Node
import asmble.ast.SExpr
import asmble.ast.Script
open class AstToSExpr {
open class AstToSExpr(val parensInstrs: Boolean = true) {
fun fromAction(v: Script.Cmd.Action) = when(v) {
is Script.Cmd.Action.Invoke -> newMulti("invoke", v.name) + v.string.quoted + v.exprs.flatMap(this::fromInstrs)
@ -21,11 +21,9 @@ open class AstToSExpr {
newMulti("assert_trap") + fromAction(v.action) + v.failure
is Script.Cmd.Assertion.Malformed -> when (v.module) {
is Script.LazyModule.SExpr -> newMulti("assert_malformed") + v.module.sexpr + v.failure
else -> newMulti("assert_malformed") + fromModule(v.module.value) + v.failure
}
is Script.Cmd.Assertion.Invalid -> when (v.module) {
is Script.LazyModule.SExpr -> newMulti("assert_invalid") + v.module.sexpr + v.failure
else -> newMulti("assert_invalid") + fromModule(v.module.value) + v.failure
}
is Script.Cmd.Assertion.SoftInvalid ->
newMulti("assert_soft_invalid") + fromModule(v.module) + v.failure
@ -47,10 +45,11 @@ open class AstToSExpr {
fun fromData(v: Node.Data) =
(newMulti("data") + v.index) + (newMulti("offset") +
fromInstrs(v.offset)) + v.data.toString(Charsets.UTF_8).quoted
fromInstrs(v.offset).unwrapInstrs()) + v.data.toString(Charsets.UTF_8).quoted
fun fromElem(v: Node.Elem) =
(newMulti("elem") + v.index) + (newMulti("offset") + fromInstrs(v.offset)) + v.funcIndices.map(this::fromNum)
(newMulti("elem") + v.index) + (newMulti("offset") + fromInstrs(v.offset).unwrapInstrs()) +
v.funcIndices.map(this::fromNum)
fun fromElemType(v: Node.ElemType) = when(v) {
Node.ElemType.ANYFUNC -> fromString("anyfunc")
@ -63,25 +62,34 @@ open class AstToSExpr {
Node.ExternalKind.GLOBAL -> newMulti("global") + v.index
}
fun fromFunc(v: Node.Func, name: String? = null, impExp: ImportOrExport? = null) =
newMulti("func", name) + impExp?.let(this::fromImportOrExport) + fromFuncSig(v.type) +
fromLocals(v.locals) + fromInstrs(v.instructions)
fun fromFunc(
v: Node.Func,
name: String? = null,
impExp: ImportOrExport? = null,
localNames: Map<Int, String> = emptyMap()
) =
newMulti("func", name) + impExp?.let(this::fromImportOrExport) + fromFuncSig(v.type, localNames) +
fromLocals(v.locals, v.type.params.size, localNames) + fromInstrs(v.instructions).unwrapInstrs()
fun fromFuncSig(v: Node.Type.Func): List<SExpr> {
fun fromFuncSig(v: Node.Type.Func, localNames: Map<Int, String> = emptyMap()): List<SExpr> {
var ret = emptyList<SExpr>()
if (v.params.isNotEmpty()) ret += newMulti("param") + v.params.map(this::fromType)
if (v.params.isNotEmpty()) {
if (localNames.isEmpty()) ret += newMulti("param") + v.params.map(this::fromType)
else ret += v.params.mapIndexed { index, param -> newMulti("param", localNames[index]) + fromType(param) }
}
v.ret?.also { ret += newMulti("result") + fromType(it) }
return ret
}
fun fromGlobal(v: Node.Global, name: String? = null, impExp: ImportOrExport? = null) =
newMulti("global", name) + impExp?.let(this::fromImportOrExport) + fromGlobalSig(v.type) + fromInstrs(v.init)
newMulti("global", name) + impExp?.let(this::fromImportOrExport) +
fromGlobalSig(v.type) + fromInstrs(v.init).unwrapInstrs()
fun fromGlobalSig(v: Node.Type.Global) =
if (v.mutable) newMulti("mut") + fromType(v.contentType) else fromType(v.contentType)
fun fromImport(v: Node.Import, types: List<Node.Type.Func>) =
(newMulti("import") + v.module.quoted) + v.field.quoted + fromImportKind(v.kind, types)
fun fromImport(v: Node.Import, types: List<Node.Type.Func>, name: String? = null) =
(newMulti("import") + v.module.quoted) + v.field.quoted + fromImportKind(v.kind, types, name)
fun fromImportFunc(v: Node.Import.Kind.Func, types: List<Node.Type.Func>, name: String? = null) =
fromImportFunc(types.getOrElse(v.typeIndex) { throw Exception("No type at ${v.typeIndex}") }, name)
@ -91,19 +99,20 @@ open class AstToSExpr {
fun fromImportGlobal(v: Node.Import.Kind.Global, name: String? = null) =
newMulti("global", name) + fromGlobalSig(v.type)
fun fromImportKind(v: Node.Import.Kind, types: List<Node.Type.Func>) = when(v) {
is Node.Import.Kind.Func -> fromImportFunc(v, types)
is Node.Import.Kind.Table -> fromImportTable(v)
is Node.Import.Kind.Memory -> fromImportMemory(v)
is Node.Import.Kind.Global -> fromImportGlobal(v)
fun fromImportKind(v: Node.Import.Kind, types: List<Node.Type.Func>, name: String? = null) = when(v) {
is Node.Import.Kind.Func -> fromImportFunc(v, types, name)
is Node.Import.Kind.Table -> fromImportTable(v, name)
is Node.Import.Kind.Memory -> fromImportMemory(v, name)
is Node.Import.Kind.Global -> fromImportGlobal(v, name)
}
fun fromImportMemory(v: Node.Import.Kind.Memory, name: String? = null) =
newMulti("memory", name) + fromMemorySig(v.type)
fun fromImportOrExport(v: ImportOrExport) =
if (v.importModule == null) newMulti("export") + v.field
else (newMulti("import") + v.importModule) + v.field
fun fromImportOrExport(v: ImportOrExport) = when (v) {
is ImportOrExport.Import -> listOf((newMulti("import") + v.module) + v.name)
is ImportOrExport.Export -> v.fields.map { newMulti("export") + it }
}
fun fromImportTable(v: Node.Import.Kind.Table, name: String? = null) =
newMulti("table", name) + fromTableSig(v.type)
@ -131,10 +140,39 @@ open class AstToSExpr {
}
}
fun fromInstrs(v: List<Node.Instr>) = v.map(this::fromInstr)
fun fromInstrs(v: List<Node.Instr>): List<SExpr.Multi> {
var index = 0
fun untilNext(vararg insns: Node.Instr): Pair<List<SExpr>, Node.Instr?> {
var ret = emptyList<SExpr>()
while (index < v.size) {
val insn = v[index]
index++
if (insns.contains(insn)) return ret to insn
var expr = fromInstr(insn)
if (insn is Node.Instr.Block || insn is Node.Instr.Loop) {
expr += untilNext(Node.Instr.End).first
if (!parensInstrs) expr += "end"
} else if (insn is Node.Instr.If) untilNext(Node.Instr.Else, Node.Instr.End).let { (subList, insn) ->
var elseList = subList
if (insn is Node.Instr.Else) {
if (parensInstrs) expr += newMulti("then") + subList else expr = expr + subList + "end"
elseList = untilNext(Node.Instr.End).first
}
if (parensInstrs) expr += newMulti("else") + elseList else expr = expr + elseList + "end"
}
if (parensInstrs) ret += expr else ret += expr.vals
}
require(insns.isEmpty()) { "Expected one of ${insns.map { it.op().name }}, got none" }
return ret to null
}
if (parensInstrs) return untilNext().first.map { it as SExpr.Multi }
return listOf(SExpr.Multi(untilNext().first))
}
fun fromLocals(v: List<Node.Type.Value>) =
if (v.isEmpty()) null else newMulti("local") + v.map(this::fromType)
fun fromLocals(v: List<Node.Type.Value>, paramOffset: Int, localNames: Map<Int, String> = emptyMap()) =
if (v.isEmpty()) emptyList()
else if (localNames.isEmpty()) listOf(newMulti("local") + v.map(this::fromType))
else v.mapIndexed { index, v -> newMulti("local", localNames[paramOffset + index]) + fromType(v) }
fun fromMemory(v: Node.Type.Memory, name: String? = null, impExp: ImportOrExport? = null) =
newMulti("memory", name) + impExp?.let(this::fromImportOrExport) + fromMemorySig(v)
@ -147,7 +185,7 @@ open class AstToSExpr {
is Script.Cmd.Meta.Output -> newMulti("output", v.name) + v.str
}
fun fromModule(v: Node.Module, name: String? = null): SExpr.Multi {
fun fromModule(v: Node.Module, name: String? = v.names?.moduleName): SExpr.Multi {
var ret = newMulti("module", name)
// If there is a call_indirect, then we need to output all types in exact order.
@ -159,8 +197,14 @@ open class AstToSExpr {
v.types.filterIndexed { i, _ -> importIndices.contains(i) } - v.funcs.map { it.type }
}
// Keep track of the current function index for names
var funcIndex = -1
ret += types.map { fromTypeDef(it) }
ret += v.imports.map { fromImport(it, v.types) }
ret += v.imports.map {
if (it.kind is Node.Import.Kind.Func) funcIndex++
fromImport(it, v.types, v.names?.funcNames?.get(funcIndex))
}
ret += v.exports.map(this::fromExport)
ret += v.tables.map { fromTable(it) }
ret += v.memories.map { fromMemory(it) }
@ -168,7 +212,14 @@ open class AstToSExpr {
ret += v.elems.map(this::fromElem)
ret += v.data.map(this::fromData)
ret += v.startFuncIndex?.let(this::fromStart)
ret += v.funcs.map { fromFunc(it) }
ret += v.funcs.map {
funcIndex++
fromFunc(
v = it,
name = v.names?.funcNames?.get(funcIndex),
localNames = v.names?.localNames?.get(funcIndex) ?: emptyMap()
)
}
return ret
}
@ -205,12 +256,12 @@ open class AstToSExpr {
if (exp == null) this else this.copy(vals = this.vals + fromString(exp))
private operator fun SExpr.Multi.plus(exp: SExpr?) =
if (exp == null) this else this.copy(vals = this.vals + exp)
private operator fun SExpr.Multi.plus(exps: List<SExpr>) =
if (exps.isEmpty()) this else this.copy(vals = this.vals + exps)
private fun newMulti(initSymb: String? = null, initName: String? = null): SExpr.Multi {
initName?.also { require(it.startsWith("$")) }
return SExpr.Multi() + initSymb + initName
}
private operator fun SExpr.Multi.plus(exps: List<SExpr>?) =
if (exps == null || exps.isEmpty()) this else this.copy(vals = this.vals + exps)
private fun newMulti(initSymb: String? = null, initName: String? = null) =
SExpr.Multi() + initSymb + initName?.let { "$$it" }
private fun List<SExpr.Multi>.unwrapInstrs() =
if (parensInstrs) this else this.single().vals
private val String.quoted get() = fromString(this, true)
companion object : AstToSExpr()

View File

@ -2,10 +2,12 @@ package asmble.io
import asmble.ast.Node
import asmble.util.*
import java.nio.ByteBuffer
open class BinaryToAst(
val version: Long = 1L,
val logger: Logger = Logger.Print(Logger.Level.OFF)
val logger: Logger = Logger.Print(Logger.Level.WARN),
val includeNameSection: Boolean = true
) : Logger by logger {
fun toBlockType(b: ByteReader) = b.readVarInt7().toInt().let {
@ -18,6 +20,23 @@ open class BinaryToAst(
payload = b.readBytes()
)
fun toNameSection(b: ByteReader) = generateSequence {
if (b.isEof) null
else b.readVarUInt7().toInt() to b.read(b.readVarUInt32AsInt())
}.fold(Node.NameSection(null, emptyMap(), emptyMap())) { sect, (type, b) ->
fun <T> indexMap(b: ByteReader, fn: (ByteReader) -> T) =
b.readList { it.readVarUInt32AsInt() to fn(it) }.let { pairs ->
pairs.toMap().also { require(it.size == pairs.size) { "Malformed names: duplicate indices" } }
}
fun nameMap(b: ByteReader) = indexMap(b) { it.readString() }
when (type) {
0 -> sect.copy(moduleName = b.readString())
1 -> sect.copy(funcNames = nameMap(b))
2 -> sect.copy(localNames = indexMap(b, ::nameMap))
else -> error("Malformed names: unrecognized type: $type")
}.also { require(b.isEof) }
}
fun toData(b: ByteReader) = Node.Data(
index = b.readVarUInt32AsInt(),
offset = toInitExpr(b),
@ -112,7 +131,7 @@ open class BinaryToAst(
op.create(b.readVarUInt32AsInt())
is Node.InstrOp.CallOp.IndexReservedArg -> op.create(
b.readVarUInt32AsInt(),
b.readVarUInt1()
b.readVarUInt1().also { if (it) throw IoErr.InvalidReservedArg() }
)
is Node.InstrOp.ParamOp.NoArg ->
op.create
@ -123,7 +142,7 @@ open class BinaryToAst(
b.readVarUInt32()
)
is Node.InstrOp.MemOp.ReservedArg ->
op.create(b.readVarUInt1())
op.create(b.readVarUInt1().also { if (it) throw IoErr.InvalidReservedArg() })
is Node.InstrOp.ConstOp.IntArg ->
op.create(b.readVarInt32())
is Node.InstrOp.ConstOp.LongArg ->
@ -143,12 +162,15 @@ open class BinaryToAst(
}
}
fun toLocals(b: ByteReader) = b.readVarUInt32AsInt().let { size ->
toValueType(b).let { type -> List(size) { type } }
fun toLocals(b: ByteReader): List<Node.Type.Value> {
val size = try { b.readVarUInt32AsInt() } catch (e: NumberFormatException) { throw IoErr.InvalidLocalSize(e) }
return toValueType(b).let { type -> List(size) { type } }
}
fun toMemoryType(b: ByteReader) = Node.Type.Memory(toResizableLimits(b))
fun toModule(b: ByteArray) = toModule(ByteReader.InputStream(b.inputStream()))
fun toModule(b: ByteReader): Node.Module {
if (b.readUInt32() != 0x6d736100L) throw IoErr.InvalidMagicNumber()
b.readUInt32().let { if (it != version) throw IoErr.InvalidVersion(it, listOf(version)) }
@ -163,14 +185,17 @@ open class BinaryToAst(
require(sectionId > maxSectionId) { "Section ID $sectionId came after $maxSectionId" }.
also { maxSectionId = sectionId }
val sectionLen = b.readVarUInt32AsInt()
// each 'read' invocation creates new InputStream and don't closes it
sections += sectionId to b.read(sectionLen)
}
// Now build the module
fun <T> readSectionList(sectionId: Int, fn: (ByteReader) -> T) =
sections.find { it.first == sectionId }?.second?.readList(fn) ?: emptyList()
val types = readSectionList(1, this::toFuncType)
val funcIndices = readSectionList(3) { it.readVarUInt32AsInt() }
var nameSection: Node.NameSection? = null
return Node.Module(
types = types,
imports = readSectionList(2, this::toImport),
@ -191,10 +216,18 @@ open class BinaryToAst(
val afterSectionId = if (index == 0) 0 else sections[index - 1].let { (prevSectionId, _) ->
if (prevSectionId == 0) customSections.last().afterSectionId else prevSectionId
}
customSections + toCustomSection(b, afterSectionId)
// Try to parse the name section
val section = toCustomSection(b, afterSectionId).takeIf { section ->
val shouldParseNames = includeNameSection && nameSection == null && section.name == "name"
!shouldParseNames || try {
nameSection = toNameSection(ByteReader.InputStream(section.payload.inputStream()))
false
} catch (e: Exception) { warn { "Failed parsing name section: $e" }; true }
}
if (section == null) customSections else customSections + section
}
}
)
).copy(names = nameSection)
}
fun toResizableLimits(b: ByteReader) = b.readVarUInt1().let {
@ -215,7 +248,10 @@ open class BinaryToAst(
else -> error("Unknown value type: $type")
}
fun ByteReader.readString() = this.readVarUInt32AsInt().let { String(this.readBytes(it)) }
fun ByteReader.readString() = this.readVarUInt32AsInt().let {
// We have to use the decoder directly to get malformed-input errors
Charsets.UTF_8.newDecoder().decode(ByteBuffer.wrap(this.readBytes(it))).toString()
}
fun <T> ByteReader.readList(fn: (ByteReader) -> T) = this.readVarUInt32().let { listSize ->
(0 until listSize).map { fn(this) }
}

View File

@ -1,12 +1,12 @@
package asmble.io
import asmble.util.toIntExact
import asmble.util.toUnsignedBigInt
import asmble.util.toUnsignedLong
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.math.BigInteger
abstract class ByteReader {
abstract val isEof: Boolean
@ -34,27 +34,30 @@ abstract class ByteReader {
}
fun readVarInt7() = readSignedLeb128().let {
require(it >= Byte.MIN_VALUE.toLong() && it <= Byte.MAX_VALUE.toLong())
if (it < Byte.MIN_VALUE.toLong() || it > Byte.MAX_VALUE.toLong()) throw IoErr.InvalidLeb128Number()
it.toByte()
}
fun readVarInt32() = readSignedLeb128().toIntExact()
fun readVarInt32() = readSignedLeb128().let {
if (it < Int.MIN_VALUE.toLong() || it > Int.MAX_VALUE.toLong()) throw IoErr.InvalidLeb128Number()
it.toInt()
}
fun readVarInt64() = readSignedLeb128()
fun readVarInt64() = readSignedLeb128(9)
fun readVarUInt1() = readUnsignedLeb128().let {
require(it == 1 || it == 0)
if (it != 1 && it != 0) throw IoErr.InvalidLeb128Number()
it == 1
}
fun readVarUInt7() = readUnsignedLeb128().let {
require(it <= 255)
if (it > 255) throw IoErr.InvalidLeb128Number()
it.toShort()
}
fun readVarUInt32() = readUnsignedLeb128().toUnsignedLong()
protected fun readUnsignedLeb128(): Int {
protected fun readUnsignedLeb128(maxCount: Int = 4): Int {
// Taken from Android source, Apache licensed
var result = 0
var cur: Int
@ -63,12 +66,12 @@ abstract class ByteReader {
cur = readByte().toInt() and 0xff
result = result or ((cur and 0x7f) shl (count * 7))
count++
} while (cur and 0x80 == 0x80 && count < 5)
if (cur and 0x80 == 0x80) throw NumberFormatException()
} while (cur and 0x80 == 0x80 && count <= maxCount)
if (cur and 0x80 == 0x80) throw IoErr.InvalidLeb128Number()
return result
}
private fun readSignedLeb128(): Long {
private fun readSignedLeb128(maxCount: Int = 4): Long {
// Taken from Android source, Apache licensed
var result = 0L
var cur: Int
@ -79,12 +82,25 @@ abstract class ByteReader {
result = result or ((cur and 0x7f).toLong() shl (count * 7))
signBits = signBits shl 7
count++
} while (cur and 0x80 == 0x80 && count < 10)
if (cur and 0x80 == 0x80) throw NumberFormatException()
} while (cur and 0x80 == 0x80 && count <= maxCount)
if (cur and 0x80 == 0x80) throw IoErr.InvalidLeb128Number()
// Check for 64 bit invalid, taken from Apache/MIT licensed:
// https://github.com/paritytech/parity-wasm/blob/2650fc14c458c6a252c9dc43dd8e0b14b6d264ff/src/elements/primitives.rs#L351
// TODO: probably need 32 bit checks too, but meh, not in the suite
if (count > maxCount && maxCount == 9) {
if (cur and 0b0100_0000 == 0b0100_0000) {
if ((cur or 0b1000_0000).toByte() != (-1).toByte()) throw IoErr.InvalidLeb128Number()
} else if (cur != 0) {
throw IoErr.InvalidLeb128Number()
}
}
if ((signBits shr 1) and result != 0L) result = result or signBits
return result
}
// todo looks like this InputStream isn't ever closed
class InputStream(val ins: java.io.InputStream) : ByteReader() {
private var nextByte: Byte? = null
private var sawEof = false

View File

@ -4,8 +4,6 @@ import asmble.util.unsignedToSignedInt
import asmble.util.unsignedToSignedLong
import java.io.ByteArrayOutputStream
import java.math.BigInteger
import java.nio.ByteBuffer
import java.nio.ByteOrder
abstract class ByteWriter {
abstract val written: Int

View File

@ -0,0 +1,18 @@
package asmble.io
/*
data class ImportOrExport(val field: String, val importModule: String?) {
}
*/
sealed class ImportOrExport {
abstract val itemCount: Int
data class Import(val name: String, val module: String, val exportFields: List<String>) : ImportOrExport() {
override val itemCount get() = 1 + exportFields.size
}
data class Export(val fields: List<String>) : ImportOrExport() {
override val itemCount get() = fields.size
}
}

View File

@ -0,0 +1,130 @@
package asmble.io
import asmble.AsmErr
sealed class IoErr(message: String, cause: Throwable? = null) : RuntimeException(message, cause), AsmErr {
class UnexpectedEnd : IoErr("Unexpected EOF") {
override val asmErrString get() = "unexpected end"
override val asmErrStrings get() = listOf(asmErrString, "length out of bounds")
}
class InvalidMagicNumber : IoErr("Invalid magic number") {
override val asmErrString get() = "magic header not detected"
}
class InvalidVersion(actual: Long, expected: List<Long>) : IoErr("Got version $actual, only accepts $expected") {
override val asmErrString get() = "unknown binary version"
}
class InvalidSectionId(id: Int) : IoErr("Invalid section ID of $id") {
override val asmErrString get() = "invalid section id"
// Since we test section length before section content, we get a different error than the spec
override val asmErrStrings get() = listOf(asmErrString, "invalid mutability")
}
class InvalidCodeLength(funcLen: Int, codeLen: Int) : IoErr("Got $funcLen funcs but only $codeLen bodies") {
override val asmErrString get() = "function and code section have inconsistent lengths"
}
class InvalidMutability : IoErr("Invalid mutability boolean") {
override val asmErrString get() = "invalid mutability"
}
class InvalidReservedArg : IoErr("Invalid reserved arg") {
override val asmErrString get() = "zero flag expected"
}
class MultipleMemories : IoErr("Only single memory allowed") {
override val asmErrString get() = "multiple memories"
}
class MultipleTables : IoErr("Only single table allowed") {
override val asmErrString get() = "multiple tables"
}
class MemoryInitMaxMismatch(val init: Int, val max: Int) : IoErr("Memory init $init is over max $max") {
override val asmErrString get() = "memory size minimum must not be greater than maximum"
}
class MemorySizeOverflow(val given: Long) : IoErr("Memory $given cannot exceed 65536 (4GiB)") {
override val asmErrString get() = "memory size must be at most 65536 pages (4GiB)"
}
class InvalidAlignPower(val align: Int) : IoErr("Alignment expected to be positive power of 2, but got $align") {
override val asmErrString get() = "alignment must be positive power of 2"
}
class InvalidAlignTooLarge(val align: Int, val allowed: Int) : IoErr("Alignment $align larger than $allowed") {
override val asmErrString get() = "alignment must not be larger than natural"
}
class InvalidResultArity : IoErr("Only single results supported") {
override val asmErrString get() = "invalid result arity"
}
class UnknownType(val index: Int) : IoErr("No type present for index $index") {
override val asmErrString get() = "unknown type"
}
class InvalidType(val str: String) : IoErr("Invalid type: $str") {
override val asmErrString get() = "unexpected token"
}
class MismatchLabelEnd(val expected: String?, val actual: String) :
IoErr("Expected end for $expected, got $actual") {
override val asmErrString get() = "mismatching label"
}
class ConstantOutOfRange(val actual: Number) : IoErr("Constant out of range: $actual") {
override val asmErrString get() = "constant out of range"
}
class ConstantUnknownOperator(val str: String) : IoErr("Unknown constant operator for: $str") {
override val asmErrString get() = "unknown operator"
}
class FuncTypeRefMismatch : IoErr("Func type for type ref doesn't match explicit params/returns") {
override val asmErrString get() = "inline function type"
override val asmErrStrings get() = listOf(asmErrString, "unexpected token")
}
class UnrecognizedInstruction(val instr: String) : IoErr("Unrecognized instruction: $instr") {
override val asmErrString get() = "unexpected token"
override val asmErrStrings get() = listOf(asmErrString, "unknown operator")
}
class ImportAfterNonImport(val nonImportType: String) : IoErr("Import happened after $nonImportType") {
override val asmErrString get() = "import after $nonImportType"
}
class UnknownOperator : IoErr("Unknown operator") {
override val asmErrString get() = "unknown operator"
override val asmErrStrings get() = listOf(asmErrString, "unexpected token")
}
class InvalidVar(val found: String) : IoErr("Var ref expected, found: $found") {
override val asmErrString get() = "unknown operator"
}
class ResultBeforeParameter : IoErr("Function result before parameter") {
override val asmErrString get() = "result before parameter"
override val asmErrStrings get() = listOf(asmErrString, "unexpected token")
}
class IndirectCallSetParamNames : IoErr("Indirect call tried to set name to param in func type") {
override val asmErrString get() = "unexpected token"
}
class InvalidUtf8Encoding : IoErr("Some byte sequence was not UTF-8 compatible") {
override val asmErrString get() = "invalid UTF-8 encoding"
}
class InvalidLeb128Number : IoErr("Invalid LEB128 number") {
override val asmErrString get() = "integer representation too long"
override val asmErrStrings get() = listOf(asmErrString, "integer too large")
}
class InvalidLocalSize(cause: NumberFormatException) : IoErr("Invalid local size", cause) {
override val asmErrString get() = "too many locals"
}
}

View File

@ -9,10 +9,28 @@ import asmble.util.*
import java.io.ByteArrayInputStream
import java.math.BigInteger
typealias NameMap = Map<String, Int>
open class SExprToAst(
val includeNames: Boolean = true
) {
data class ExprContext(
val nameMap: NameMap,
val blockDepth: Int = 0,
val types: List<Node.Type.Func> = emptyList(),
val callIndirectNeverBeforeSeenFuncTypes: MutableList<Node.Type.Func> = mutableListOf()
) {
companion object {
val empty = ExprContext(NameMap(emptyMap(), null, null))
}
}
open class SExprToAst {
data class ExprContext(val nameMap: NameMap, val blockDepth: Int = 0)
data class FuncResult(
val name: String?,
val func: Node.Func,
val importOrExport: ImportOrExport?,
// These come from call_indirect insns
val additionalFuncTypesToAdd: List<Node.Type.Func>,
val nameMap: NameMap
)
fun toAction(exp: SExpr.Multi): Script.Cmd.Action {
var index = 1
@ -23,7 +41,7 @@ open class SExprToAst {
return when(exp.vals.first().symbolStr()) {
"invoke" ->
Script.Cmd.Action.Invoke(name, str, exp.vals.drop(index).map {
toExprMaybe(it as SExpr.Multi, ExprContext(emptyMap()))
toExprMaybe(it as SExpr.Multi, ExprContext.empty)
})
"get" ->
Script.Cmd.Action.Get(name, str)
@ -36,7 +54,7 @@ open class SExprToAst {
return when(exp.vals.first().symbolStr()) {
"assert_return" ->
Script.Cmd.Assertion.Return(toAction(mult),
exp.vals.drop(2).map { toExprMaybe(it as SExpr.Multi, ExprContext(emptyMap())) })
exp.vals.drop(2).map { toExprMaybe(it as SExpr.Multi, ExprContext.empty) })
"assert_return_canonical_nan" ->
Script.Cmd.Assertion.ReturnNan(toAction(mult), canonical = true)
"assert_return_arithmetic_nan" ->
@ -66,13 +84,15 @@ open class SExprToAst {
}
fun toBlockSigMaybe(exp: SExpr.Multi, offset: Int): List<Node.Type.Value> {
val types = exp.vals.drop(offset).takeUntilNullLazy { if (it is SExpr.Symbol) toTypeMaybe(it) else null }
val multi = exp.vals.getOrNull(offset) as? SExpr.Multi
if (multi == null || multi.vals.firstOrNull()?.symbolStr() != "result") return emptyList()
val types = multi.vals.drop(1).map { it.symbol()?.let { toTypeMaybe(it) } ?: error("Unknown type on $it") }
// We can only handle one type for now
require(types.size <= 1)
return types
}
fun toCmd(exp: SExpr.Multi): Script.Cmd {
fun toCmdMaybe(exp: SExpr.Multi): Script.Cmd? {
val expName = exp.vals.first().symbolStr()
return when(expName) {
"module" ->
@ -87,7 +107,7 @@ open class SExprToAst {
"script", "input", "output" ->
toMeta(exp)
else ->
error("Unrecognized cmd expr '$expName'")
null
}
}
@ -133,7 +153,7 @@ open class SExprToAst {
fun toExport(exp: SExpr.Multi, nameMap: NameMap): Node.Export {
exp.requireFirstSymbol("export")
val field = exp.vals[1].symbolStr()!!
val field = exp.vals[1].symbolUtf8Str()!!
val kind = exp.vals[2] as SExpr.Multi
val extKind = when(kind.vals[0].symbolStr()) {
"func" -> Node.ExternalKind.FUNCTION
@ -152,7 +172,7 @@ open class SExprToAst {
if (maybeOpAndOffset != null) {
// Everything left in the multi should be a a multi expression
return exp.vals.drop(maybeOpAndOffset.second).flatMap {
toExprMaybe(it as SExpr.Multi, ctx)
toExprMaybe(it as SExpr.Multi, ctx).also { if (it.isEmpty()) throw IoErr.UnknownOperator() }
} + maybeOpAndOffset.first
}
// Other blocks take up the rest (ignore names)
@ -161,7 +181,7 @@ open class SExprToAst {
var innerCtx = ctx.copy(blockDepth = ctx.blockDepth + 1)
exp.maybeName(opOffset)?.also {
opOffset++
innerCtx = innerCtx.copy(nameMap = innerCtx.nameMap + ("block:$it" to innerCtx.blockDepth))
innerCtx = innerCtx.copy(nameMap = innerCtx.nameMap.add("block", it, innerCtx.blockDepth))
}
val sigs = toBlockSigMaybe(exp, opOffset)
@ -208,24 +228,36 @@ open class SExprToAst {
exp: SExpr.Multi,
origNameMap: NameMap,
types: List<Node.Type.Func>
): Triple<String?, Node.Func, ImportOrExport?> {
): FuncResult {
exp.requireFirstSymbol("func")
var currentIndex = 1
val name = exp.maybeName(currentIndex)
if (name != null) currentIndex++
val maybeImpExp = toImportOrExportMaybe(exp, currentIndex)
if (maybeImpExp != null) currentIndex++
maybeImpExp?.also { currentIndex += it.itemCount }
var (nameMap, exprsUsed, sig) = toFuncSig(exp, currentIndex, origNameMap, types)
currentIndex += exprsUsed
val locals = exp.repeated("local", currentIndex, { toLocals(it) }).mapIndexed { index, (nameMaybe, vals) ->
nameMaybe?.also { require(vals.size == 1); nameMap += "local:$it" to (index + sig.params.size) }
nameMaybe?.also { require(vals.size == 1); nameMap = nameMap.add("local", it, index + sig.params.size) }
vals
}
currentIndex += locals.size
val (instrs, _) = toInstrs(exp, currentIndex, ExprContext(nameMap))
// We create a context for insn parsing (it even has sa mutable var)
val ctx = ExprContext(
nameMap = nameMap,
// We add ourselves to the context type list if we're not there to keep the indices right
types = if (types.contains(sig)) types else types + sig
)
val (instrs, _) = toInstrs(exp, currentIndex, ctx)
// Imports can't have locals or instructions
if (maybeImpExp?.importModule != null) require(locals.isEmpty() && instrs.isEmpty())
return Triple(name, Node.Func(sig, locals.flatten(), instrs), maybeImpExp)
if (maybeImpExp is ImportOrExport.Import) require(locals.isEmpty() && instrs.isEmpty())
return FuncResult(
name = name,
func = Node.Func(sig, locals.flatten(), instrs),
importOrExport = maybeImpExp,
additionalFuncTypesToAdd = ctx.callIndirectNeverBeforeSeenFuncTypes,
nameMap = nameMap
)
}
fun toFuncSig(
@ -242,19 +274,25 @@ open class SExprToAst {
} else null to offset
var nameMap = origNameMap
val params = exp.repeated("param", offset, { toParams(it) }).mapIndexed { index, (nameMaybe, vals) ->
nameMaybe?.also { require(vals.size == 1); nameMap += "local:$it" to index }
nameMaybe?.also { require(vals.size == 1); nameMap = nameMap.add("local", it, index) }
vals
}
val results = exp.repeated("result", offset + params.size, this::toResult)
val resultExps = exp.repeated("result", offset + params.size, this::toResult)
val results = resultExps.flatten()
if (results.size > 1) throw IoErr.InvalidResultArity()
val usedExps = params.size + results.size + if (typeRef == null) 0 else 1
val usedExps = params.size + resultExps.size + if (typeRef == null) 0 else 1
// Make sure there aren't parameters following the result
if (resultExps.isNotEmpty() && (exp.vals.getOrNull(offset + params.size + resultExps.size) as? SExpr.Multi)?.
vals?.firstOrNull()?.symbolStr() == "param") {
throw IoErr.ResultBeforeParameter()
}
// Check against type ref
if (typeRef != null) {
// No params or results means just use it
if (params.isEmpty() && results.isEmpty()) return Triple(nameMap, usedExps, typeRef)
// Otherwise, just make sure it matches
require(typeRef.params == params.flatten() && typeRef.ret == results.firstOrNull()) {
"Params for type ref do not match explicit ones"
if (typeRef.params != params.flatten() || typeRef.ret != results.firstOrNull()) {
throw IoErr.FuncTypeRefMismatch()
}
}
return Triple(nameMap, usedExps, Node.Type.Func(params.flatten(), results.firstOrNull()))
@ -266,12 +304,12 @@ open class SExprToAst {
val name = exp.maybeName(currIndex)
if (name != null) currIndex++
val maybeImpExp = toImportOrExportMaybe(exp, currIndex)
if (maybeImpExp != null) currIndex++
maybeImpExp?.also { currIndex += it.itemCount }
val sig = toGlobalSig(exp.vals[currIndex])
currIndex++
val (instrs, _) = toInstrs(exp, currIndex, ExprContext(nameMap))
// Imports can't have instructions
require((maybeImpExp?.importModule != null) == instrs.isEmpty())
if (maybeImpExp is ImportOrExport.Import) require(instrs.isEmpty())
return Triple(name, Node.Global(sig, instrs), maybeImpExp)
}
@ -283,15 +321,19 @@ open class SExprToAst {
}
}
fun toImport(exp: SExpr.Multi): Triple<String, String, Node.Type> {
fun toImport(
exp: SExpr.Multi,
origNameMap: NameMap,
types: List<Node.Type.Func>
): Triple<String, String, Node.Type> {
exp.requireFirstSymbol("import")
val module = exp.vals[1].symbolStr()!!
val field = exp.vals[2].symbolStr()!!
val module = exp.vals[1].symbolUtf8Str()!!
val field = exp.vals[2].symbolUtf8Str()!!
val kind = exp.vals[3] as SExpr.Multi
val kindName = kind.vals.firstOrNull()?.symbolStr()
val kindSubOffset = if (kind.maybeName(1) == null) 1 else 2
return Triple(module, field, when(kindName) {
"func" -> toFuncSig(kind, kindSubOffset, emptyMap(), emptyList()).third
"func" -> toFuncSig(kind, kindSubOffset, origNameMap, types).third
"global" -> toGlobalSig(kind.vals[kindSubOffset])
"table" -> toTableSig(kind, kindSubOffset)
"memory" -> toMemorySig(kind, kindSubOffset)
@ -301,12 +343,23 @@ open class SExprToAst {
fun toImportOrExportMaybe(exp: SExpr.Multi, offset: Int): ImportOrExport? {
if (offset >= exp.vals.size) return null
val multi = exp.vals[offset] as? SExpr.Multi ?: return null
val multiHead = multi.vals[0] as? SExpr.Symbol ?: return null
return when (multiHead.contents) {
"export" -> ImportOrExport(multi.vals[1].symbolStr()!!, null)
"import" -> ImportOrExport(multi.vals[2].symbolStr()!!, multi.vals[1].symbolStr()!!)
else -> null
var currOffset = offset
// Get all export fields first
var exportFields = emptyList<String>()
while (true) {
val multi = exp.vals.getOrNull(currOffset) as? SExpr.Multi
when (multi?.vals?.firstOrNull()?.symbolStr()) {
"import" -> return ImportOrExport.Import(
name = multi.vals.getOrNull(2)?.symbolUtf8Str() ?: error("No import name"),
module = multi.vals.getOrNull(1)?.symbolUtf8Str() ?: error("No import module"),
exportFields = exportFields
)
"export" -> multi.vals.getOrNull(1)?.symbolUtf8Str().also {
exportFields += it ?: error("No export field")
}
else -> return if (exportFields.isEmpty()) null else ImportOrExport.Export(exportFields)
}
currOffset++
}
}
@ -324,8 +377,8 @@ open class SExprToAst {
ret += maybeInstrAndOffset.first
runningOffset += maybeInstrAndOffset.second
}
if (mustCompleteExp) require(offset + runningOffset == exp.vals.size) {
"Unrecognized instruction: ${exp.vals[offset + runningOffset]}"
if (mustCompleteExp && offset + runningOffset != exp.vals.size) {
throw IoErr.UnrecognizedInstruction(exp.vals[offset + runningOffset].toString())
}
return Pair(ret, runningOffset)
}
@ -348,7 +401,7 @@ open class SExprToAst {
val maybeName = exp.maybeName(offset + opOffset)
if (maybeName != null) {
opOffset++
innerCtx = innerCtx.copy(nameMap = innerCtx.nameMap + ("block:$maybeName" to innerCtx.blockDepth))
innerCtx = innerCtx.copy(nameMap = innerCtx.nameMap.add("block", maybeName, innerCtx.blockDepth))
}
val sigs = toBlockSigMaybe(exp, offset + opOffset)
opOffset += sigs.size
@ -381,7 +434,7 @@ open class SExprToAst {
opOffset++
exp.maybeName(offset + opOffset)?.also {
opOffset++
innerCtx = innerCtx.copy(nameMap = innerCtx.nameMap + ("block:$it" to ctx.blockDepth))
innerCtx = innerCtx.copy(nameMap = innerCtx.nameMap.add("block", it, ctx.blockDepth))
}
toInstrs(exp, offset + opOffset, innerCtx, false).also {
ret += it.first
@ -397,7 +450,7 @@ open class SExprToAst {
opOffset++
exp.maybeName(offset + opOffset)?.also {
opOffset++
require(it == maybeName, { "Expected end for $maybeName, got $it" })
if (it != maybeName) throw IoErr.MismatchLabelEnd(maybeName, it)
}
return Pair(ret, opOffset)
}
@ -415,7 +468,7 @@ open class SExprToAst {
val name = exp.maybeName(currIndex)
if (name != null) currIndex++
val maybeImpExp = toImportOrExportMaybe(exp, currIndex)
if (maybeImpExp != null) currIndex++
maybeImpExp?.also { currIndex += it.itemCount }
// If it's a multi we assume "data", otherwise assume sig
val memOrData = exp.vals[currIndex].let {
when (it) {
@ -449,23 +502,34 @@ open class SExprToAst {
}
fun toModule(exp: SExpr.Multi): Pair<String?, Node.Module> {
// As a special case, if this isn't a "module", wrap it and try again
if (exp.vals.firstOrNull()?.symbolStr() != "module") {
return toModule(SExpr.Multi(listOf(SExpr.Symbol("module")) + exp.vals))
}
exp.requireFirstSymbol("module")
val name = exp.maybeName(1)
// If all of the other symbols after the name are quoted strings,
// this needs to be parsed as a binary
exp.vals.drop(if (name == null) 1 else 2).also { otherVals ->
if (otherVals.isNotEmpty() && otherVals.find { it !is SExpr.Symbol || !it.quoted } == null)
return name to toModuleFromBytes(otherVals.fold(byteArrayOf()) { bytes, strVal ->
bytes + (strVal as SExpr.Symbol).rawContentCharsToBytes()
})
// Special cases for "quote" and "binary" modules.
val quoteOrBinary = exp.vals.elementAtOrNull(if (name == null) 1 else 2)?.
symbolStr()?.takeIf { it == "quote" || it == "binary" }
if (quoteOrBinary != null) {
val bytes = exp.vals.drop(if (name == null) 2 else 3).fold(byteArrayOf()) { bytes, expr ->
bytes + (
expr.symbol()?.takeIf { it.quoted }?.rawContentCharsToBytes() ?: error("Expected quoted string")
)
}
// For binary, just load from bytes
if (quoteOrBinary == "binary") return name to toModuleFromBytes(bytes)
// Otherwise, take the quoted strings and parse em
return toModuleFromQuotedString(bytes.toString(Charsets.US_ASCII))
}
var mod = Node.Module()
val exps = exp.vals.mapNotNull { it as? SExpr.Multi }
// Eagerly build the names (for forward decls)
var nameMap = toModuleForwardNameMap(exps)
var (nameMap, eagerTypes) = toModuleForwardNameMapAndTypes(exps)
mod = mod.copy(types = eagerTypes)
fun Node.Module.addTypeIfNotPresent(type: Node.Type.Func): Pair<Node.Module, Int> {
val index = this.types.indexOf(type)
@ -478,64 +542,79 @@ open class SExprToAst {
var globalCount = 0
var tableCount = 0
var memoryCount = 0
fun handleImport(module: String, field: String, kind: Node.Type) {
fun handleImport(module: String, field: String, kind: Node.Type, exportFields: List<String>) {
// We make sure that an import doesn't happen after a non-import
require(mod.funcs.isEmpty() && mod.globals.isEmpty() &&
mod.tables.isEmpty() && mod.memories.isEmpty()) { "Import happened after non-import" }
val importKind = when(kind) {
if (mod.funcs.isNotEmpty()) throw IoErr.ImportAfterNonImport("function")
if (mod.globals.isNotEmpty()) throw IoErr.ImportAfterNonImport("global")
if (mod.tables.isNotEmpty()) throw IoErr.ImportAfterNonImport("table")
if (mod.memories.isNotEmpty()) throw IoErr.ImportAfterNonImport("memory")
val (importKind, indexAndExtKind) = when(kind) {
is Node.Type.Func -> mod.addTypeIfNotPresent(kind).let { (m, idx) ->
funcCount++
mod = m
Node.Import.Kind.Func(idx)
Node.Import.Kind.Func(idx) to (funcCount++ to Node.ExternalKind.FUNCTION)
}
is Node.Type.Global -> { globalCount++; Node.Import.Kind.Global(kind) }
is Node.Type.Table -> { tableCount++; Node.Import.Kind.Table(kind) }
is Node.Type.Memory -> { memoryCount++; Node.Import.Kind.Memory(kind) }
is Node.Type.Global ->
Node.Import.Kind.Global(kind) to (globalCount++ to Node.ExternalKind.GLOBAL)
is Node.Type.Table ->
Node.Import.Kind.Table(kind) to (tableCount++ to Node.ExternalKind.TABLE)
is Node.Type.Memory ->
Node.Import.Kind.Memory(kind) to (memoryCount++ to Node.ExternalKind.MEMORY)
else -> throw Exception("Unrecognized import kind: $kind")
}
mod = mod.copy(imports = mod.imports + Node.Import(module, field, importKind))
mod = mod.copy(
imports = mod.imports + Node.Import(module, field, importKind),
exports = mod.exports + exportFields.map {
Node.Export(it, indexAndExtKind.second, indexAndExtKind.first)
}
)
}
fun addMaybeExport(impExp: ImportOrExport?, extKind: Node.ExternalKind, index: Int) {
impExp?.also { mod = mod.copy(exports = mod.exports + Node.Export(it.field, extKind, index)) }
fun addExport(exp: ImportOrExport.Export, extKind: Node.ExternalKind, index: Int) {
mod = mod.copy(exports = mod.exports + exp.fields.map { Node.Export(it, extKind, index) })
}
// Now just handle all expressions in order
exps.forEach { exp ->
when(exp.vals.firstOrNull()?.symbolStr()) {
"import" -> toImport(exp).let { (module, field, type) -> handleImport(module, field, type) }
"type" -> toTypeDef(exp, nameMap).let { (name, type) ->
// We always add the type, even if it's a duplicate.
// Ref: https://github.com/WebAssembly/design/issues/1041
if (name != null) nameMap += "type:$name" to mod.types.size
mod = mod.copy(types = mod.types + type)
"import" -> toImport(exp, nameMap, mod.types).let { (module, field, type) ->
handleImport(module, field, type, emptyList())
}
// We do not handle types here anymore. They are handled eagerly as part of the forward pass.
"type" -> { }
"export" -> mod = mod.copy(exports = mod.exports + toExport(exp, nameMap))
"elem" -> mod = mod.copy(elems = mod.elems + toElem(exp, nameMap))
"data" -> mod = mod.copy(data = mod.data + toData(exp, nameMap))
"start" -> mod = mod.copy(startFuncIndex = toStart(exp, nameMap))
"func" -> toFunc(exp, nameMap, mod.types).also { (_, fn, impExp) ->
if (impExp != null && impExp.importModule != null) {
handleImport(impExp.importModule, impExp.field, fn.type)
"func" -> toFunc(exp, nameMap, mod.types).also { (_, fn, impExp, additionalFuncTypes, localNameMap) ->
if (impExp is ImportOrExport.Import) {
handleImport(impExp.module, impExp.name, fn.type, impExp.exportFields)
} else {
addMaybeExport(impExp, Node.ExternalKind.FUNCTION, funcCount++)
if (impExp is ImportOrExport.Export) addExport(impExp, Node.ExternalKind.FUNCTION, funcCount)
if (includeNames) nameMap = nameMap.copy(
localNames = nameMap.localNames!! + (funcCount to localNameMap.getAllNamesByIndex("local"))
)
funcCount++
mod = mod.copy(funcs = mod.funcs + fn).addTypeIfNotPresent(fn.type).first
mod = additionalFuncTypes.fold(mod) { mod, typ -> mod.addTypeIfNotPresent(typ).first }
}
}
"global" -> toGlobal(exp, nameMap).let { (_, glb, impExp) ->
if (impExp != null && impExp.importModule != null) {
handleImport(impExp.importModule, impExp.field, glb.type)
if (impExp is ImportOrExport.Import) {
handleImport(impExp.module, impExp.name, glb.type, impExp.exportFields)
} else {
addMaybeExport(impExp, Node.ExternalKind.GLOBAL, globalCount++)
if (impExp is ImportOrExport.Export) addExport(impExp, Node.ExternalKind.GLOBAL, globalCount)
globalCount++
mod = mod.copy(globals = mod.globals + glb)
}
}
"table" -> toTable(exp, nameMap).let { (_, tbl, impExp) ->
if (impExp != null && impExp.importModule != null) {
if (impExp is ImportOrExport.Import) {
if (tbl !is Either.Left) error("Elem segment on import table")
handleImport(impExp.importModule, impExp.field, tbl.v)
handleImport(impExp.module, impExp.name, tbl.v, impExp.exportFields)
} else {
addMaybeExport(impExp, Node.ExternalKind.TABLE, tableCount++)
if (impExp is ImportOrExport.Export) addExport(impExp, Node.ExternalKind.TABLE, tableCount)
tableCount++
when (tbl) {
is Either.Left -> mod = mod.copy(tables = mod.tables + tbl.v)
is Either.Right -> mod = mod.copy(
@ -549,11 +628,12 @@ open class SExprToAst {
}
}
"memory" -> toMemory(exp).let { (_, mem, impExp) ->
if (impExp != null && impExp.importModule != null) {
if (impExp is ImportOrExport.Import) {
if (mem !is Either.Left) error("Data segment on import mem")
handleImport(impExp.importModule, impExp.field, mem.v)
handleImport(impExp.module, impExp.name, mem.v, impExp.exportFields)
} else {
addMaybeExport(impExp, Node.ExternalKind.MEMORY, memoryCount++)
if (impExp is ImportOrExport.Export) addExport(impExp, Node.ExternalKind.MEMORY, memoryCount)
memoryCount++
when (mem) {
is Either.Left -> mod = mod.copy(memories = mod.memories + mem.v)
is Either.Right -> mod = mod.copy(
@ -574,12 +654,34 @@ open class SExprToAst {
if (mod.tables.size + mod.imports.count { it.kind is Node.Import.Kind.Table } > 1)
throw IoErr.MultipleTables()
// Set the name map pieces if we're including them
if (includeNames) mod = mod.copy(
names = Node.NameSection(
moduleName = name,
funcNames = nameMap.funcNames!!,
localNames = nameMap.localNames!!
)
)
return name to mod
}
fun toModuleFromBytes(bytes: ByteArray) = BinaryToAst.toModule(ByteReader.InputStream(ByteArrayInputStream(bytes)))
fun toModuleForwardNameMap(exps: List<SExpr.Multi>): NameMap {
fun toModuleFromQuotedString(str: String) = StrToSExpr.parse(str).let {
when (it) {
is StrToSExpr.ParseResult.Error -> error("Failed parsing quoted module: ${it.msg}")
is StrToSExpr.ParseResult.Success -> {
// If the result is not a single module sexpr, wrap it in one
val sexpr = it.vals.singleOrNull()?.let { it as? SExpr.Multi }?.takeIf {
it.vals.firstOrNull()?.symbolStr() == "module"
} ?: SExpr.Multi(listOf(SExpr.Symbol("module")) + it.vals)
toModule(sexpr)
}
}
}
fun toModuleForwardNameMapAndTypes(exps: List<SExpr.Multi>): Pair<NameMap, List<Node.Type.Func>> {
// We break into import and non-import because the index
// tables do imports first
val (importExps, nonImportExps) = exps.partition {
@ -597,9 +699,14 @@ open class SExprToAst {
var globalCount = 0
var tableCount = 0
var memoryCount = 0
var namesToIndices = emptyMap<String, Int>()
var nameMap = NameMap(
names = emptyMap(),
funcNames = if (includeNames) emptyMap() else null,
localNames = if (includeNames) emptyMap() else null
)
var types = emptyList<Node.Type.Func>()
fun maybeAddName(name: String?, index: Int, type: String) {
name?.let { namesToIndices += "$type:$it" to index }
name?.also { nameMap = nameMap.add(type, it, index) }
}
// All imports first
@ -625,10 +732,14 @@ open class SExprToAst {
"global" -> maybeAddName(kindName, globalCount++, "global")
"table" -> maybeAddName(kindName, tableCount++, "table")
"memory" -> maybeAddName(kindName, memoryCount++, "memory")
// We go ahead and do the full type def build here eagerly
"type" -> maybeAddName(kindName, types.size, "type").also { _ ->
toTypeDef(it, nameMap).also { (_, type) -> types += type }
}
else -> {}
}
}
return namesToIndices
return nameMap to types
}
fun toOpMaybe(exp: SExpr.Multi, offset: Int, ctx: ExprContext): Pair<Node.Instr, Int>? {
@ -661,7 +772,28 @@ open class SExprToAst {
Pair(op.create(vars.dropLast(1), vars.last()), offset + 1 + vars.size)
}
is InstrOp.CallOp.IndexArg -> Pair(op.create(oneVar("func")), 2)
is InstrOp.CallOp.IndexReservedArg -> Pair(op.create(oneVar("type"), false), 2)
is InstrOp.CallOp.IndexReservedArg -> {
// First lookup the func sig
val (updatedNameMap, expsUsed, funcType) = toFuncSig(exp, offset + 1, ctx.nameMap, ctx.types)
// Make sure there are no changes to the name map
if (ctx.nameMap.size != updatedNameMap.size)
throw IoErr.IndirectCallSetParamNames()
// Obtain the func index from the types table, the indirects table, or just add it
var funcTypeIndex = ctx.types.indexOf(funcType)
// If it's not in the type list, check the call indirect list
if (funcTypeIndex == -1) {
funcTypeIndex = ctx.callIndirectNeverBeforeSeenFuncTypes.indexOf(funcType)
// If it's not there either, add it as a fresh
if (funcTypeIndex == -1) {
funcTypeIndex = ctx.callIndirectNeverBeforeSeenFuncTypes.size
ctx.callIndirectNeverBeforeSeenFuncTypes += funcType
}
// And of course increase it by the overall type list size since they'll be added to that later
funcTypeIndex += ctx.types.size
}
Pair(op.create(funcTypeIndex, false), expsUsed + 1)
}
is InstrOp.ParamOp.NoArg -> Pair(op.create, 1)
is InstrOp.VarOp.IndexArg -> Pair(op.create(
oneVar(if (head.contents.endsWith("global")) "global" else "local")), 2)
@ -678,10 +810,10 @@ open class SExprToAst {
if (exp.vals.size > offset + count) exp.vals[offset + count].symbolStr().also {
if (it != null && it.startsWith("align=")) {
instrAlign = it.substring(6).toInt()
require(instrAlign > 0 && instrAlign and (instrAlign - 1) == 0) {
"Alignment expected to be positive power of 2, but got $instrAlign"
if (instrAlign <= 0 || instrAlign and (instrAlign - 1) != 0) {
throw IoErr.InvalidAlignPower(instrAlign)
}
if (instrAlign > op.argBits / 8) throw IoErr.InvalidAlign(instrAlign, op.argBits)
if (instrAlign > op.argBits / 8) throw IoErr.InvalidAlignTooLarge(instrAlign, op.argBits)
count++
}
}
@ -723,14 +855,18 @@ open class SExprToAst {
return Node.ResizableLimits(init.toInt(), max?.toInt())
}
fun toResult(exp: SExpr.Multi): Node.Type.Value {
fun toResult(exp: SExpr.Multi): List<Node.Type.Value> {
exp.requireFirstSymbol("result")
if (exp.vals.size > 2) throw IoErr.InvalidResultArity()
return toType(exp.vals[1].symbol()!!)
return exp.vals.drop(1).map { toType(it.symbol() ?: error("Invalid result type")) }
}
fun toScript(exp: SExpr.Multi): Script {
return Script(exp.vals.map { toCmd(it as SExpr.Multi) })
val cmds = exp.vals.map { toCmdMaybe(it as SExpr.Multi) }
// If the commands are non-empty but they are all null, it's an inline module
if (cmds.isNotEmpty() && cmds.all { it == null }) {
return toModule(exp).let { Script(listOf(Script.Cmd.Module(it.second, it.first))) }
}
return Script(cmds.filterNotNull())
}
fun toStart(exp: SExpr.Multi, nameMap: NameMap): Int {
@ -747,12 +883,12 @@ open class SExprToAst {
val name = exp.maybeName(currIndex)
if (name != null) currIndex++
val maybeImpExp = toImportOrExportMaybe(exp, currIndex)
if (maybeImpExp != null) currIndex++
maybeImpExp?.also { currIndex += it.itemCount }
// If elem type is there, we load the elems instead
val elemType = toElemTypeMaybe(exp, currIndex)
val tableOrElems =
if (elemType != null) {
require(maybeImpExp?.importModule == null)
require(maybeImpExp !is ImportOrExport.Import)
val elem = exp.vals[currIndex + 1] as SExpr.Multi
elem.requireFirstSymbol("elem")
Either.Right(Node.Elem(
@ -770,7 +906,7 @@ open class SExprToAst {
}
fun toType(exp: SExpr.Symbol): Node.Type.Value {
return toTypeMaybe(exp) ?: throw Exception("Unknown value type: ${exp.contents}")
return toTypeMaybe(exp) ?: throw IoErr.InvalidType(exp.contents)
}
fun toTypeMaybe(exp: SExpr.Symbol): Node.Type.Value? = when(exp.contents) {
@ -792,50 +928,87 @@ open class SExprToAst {
}
fun toVar(exp: SExpr.Symbol, nameMap: NameMap, nameType: String): Int {
return toVarMaybe(exp, nameMap, nameType) ?: throw Exception("No var for on exp $exp")
return toVarMaybe(exp, nameMap, nameType) ?: throw IoErr.InvalidVar(exp.toString())
}
fun toVarMaybe(exp: SExpr, nameMap: NameMap, nameType: String): Int? {
return exp.symbolStr()?.let { it ->
if (it.startsWith("$"))
nameMap["$nameType:$it"] ?:
nameMap.get(nameType, it.drop(1)) ?:
throw Exception("Unable to find index for name $it of type $nameType in $nameMap")
else if (it.startsWith("0x")) it.substring(2).toIntOrNull(16)
else it.toIntOrNull()
}
}
private fun String.sansUnderscores(): String {
// The underscores can only be between digits (which can be hex)
fun isDigit(c: Char) = c.isDigit() || (startsWith("0x", true) && (c in 'a'..'f' || c in 'A'..'F'))
var ret = this
var underscoreIndex = 0
while (true){
underscoreIndex = ret.indexOf('_', underscoreIndex)
if (underscoreIndex == -1) return ret
// Can't be at beginning or end
if (underscoreIndex == 0 || underscoreIndex == ret.length - 1 ||
!isDigit(ret[underscoreIndex - 1]) || !isDigit(ret[underscoreIndex + 1])) {
throw IoErr.ConstantUnknownOperator(this)
}
ret = ret.removeRange(underscoreIndex, underscoreIndex + 1)
}
}
private fun String.toBigIntegerConst() =
if (this.contains("0x")) BigInteger(this.replace("0x", ""), 16)
if (contains("0x")) BigInteger(replace("0x", ""), 16)
else BigInteger(this)
private fun String.toIntConst() = toBigIntegerConst().toInt()
private fun String.toLongConst() = toBigIntegerConst().toLong()
private fun String.toUnsignedIntConst() =
(if (this.contains("0x")) Long.valueOf(this.replace("0x", ""), 16)
private fun String.toIntConst() = sansUnderscores().run {
toBigIntegerConst().
also { if (it > MAX_UINT32) throw IoErr.ConstantOutOfRange(it) }.
also { if (it < MIN_INT32) throw IoErr.ConstantOutOfRange(it) }.
toInt()
}
private fun String.toLongConst() = sansUnderscores().run {
toBigIntegerConst().
also { if (it > MAX_UINT64) throw IoErr.ConstantOutOfRange(it) }.
also { if (it < MIN_INT64) throw IoErr.ConstantOutOfRange(it) }.
toLong()
}
private fun String.toUnsignedIntConst() = sansUnderscores().run {
(if (contains("0x")) Long.valueOf(replace("0x", ""), 16)
else Long.valueOf(this)).unsignedToSignedInt().toUnsignedLong()
}
private fun String.toFloatConst() =
if (this == "infinity" || this == "+infinity") Float.POSITIVE_INFINITY
else if (this == "-infinity") Float.NEGATIVE_INFINITY
private fun String.toFloatConst() = sansUnderscores().run {
if (this == "infinity" || this == "+infinity" || this == "inf" || this == "+inf") Float.POSITIVE_INFINITY
else if (this == "-infinity" || this == "-inf") Float.NEGATIVE_INFINITY
else if (this == "nan" || this == "+nan") Float.fromIntBits(0x7fc00000)
else if (this == "-nan") Float.fromIntBits(0xffc00000.toInt())
else if (this.startsWith("nan:") || this.startsWith("+nan:")) Float.fromIntBits(
0x7f800000 + this.substring(this.indexOf(':') + 1).toIntConst()
) else if (this.startsWith("-nan:")) Float.fromIntBits(
0xff800000.toInt() + this.substring(this.indexOf(':') + 1).toIntConst()
) else if (this.startsWith("0x") && !this.contains('P', true)) this.toLongConst().toFloat()
else this.toFloat()
private fun String.toDoubleConst() =
if (this == "infinity" || this == "+infinity") Double.POSITIVE_INFINITY
else if (this == "-infinity") Double.NEGATIVE_INFINITY
) else {
// If there is no "p" on a hex, we have to add it
var str = this
if (str.startsWith("0x", true) && !str.contains('P', true)) str += "p0"
str.toFloat().also { if (it.isInfinite()) throw IoErr.ConstantOutOfRange(it) }
}
}
private fun String.toDoubleConst() = sansUnderscores().run {
if (this == "infinity" || this == "+infinity" || this == "inf" || this == "+inf") Double.POSITIVE_INFINITY
else if (this == "-infinity" || this == "-inf") Double.NEGATIVE_INFINITY
else if (this == "nan" || this == "+nan") Double.fromLongBits(0x7ff8000000000000)
else if (this == "-nan") Double.fromLongBits(-2251799813685248) // i.e. 0xfff8000000000000
else if (this.startsWith("nan:") || this.startsWith("+nan:")) Double.fromLongBits(
0x7ff0000000000000 + this.substring(this.indexOf(':') + 1).toLongConst()
) else if (this.startsWith("-nan:")) Double.fromLongBits(
-4503599627370496 + this.substring(this.indexOf(':') + 1).toLongConst() // i.e. 0xfff0000000000000
) else if (this.startsWith("0x") && !this.contains('P', true)) this.toLongConst().toDouble()
else this.toDouble()
) else {
// If there is no "p" on a hex, we have to add it
var str = this
if (str.startsWith("0x", true) && !str.contains('P', true)) str += "p0"
str.toDouble().also { if (it.isInfinite()) throw IoErr.ConstantOutOfRange(it) }
}
}
private fun SExpr.requireSymbol(contents: String, quotedCheck: Boolean? = null) {
if (this is SExpr.Symbol && this.contents == contents &&
@ -848,11 +1021,15 @@ open class SExprToAst {
private fun SExpr.symbol() = this as? SExpr.Symbol
private fun SExpr.symbolStr() = this.symbol()?.contents
private fun SExpr.symbolUtf8Str() = this.symbol()?.let {
if (it.hasNonUtf8ByteSeqs) throw IoErr.InvalidUtf8Encoding()
it.contents
}
private fun SExpr.Multi.maybeName(index: Int): String? {
if (this.vals.size > index && this.vals[index] is SExpr.Symbol) {
val sym = this.vals[index] as SExpr.Symbol
if (!sym.quoted && sym.contents[0] == '$') return sym.contents
if (!sym.quoted && sym.contents[0] == '$') return sym.contents.drop(1)
}
return null
}
@ -875,5 +1052,26 @@ open class SExprToAst {
return this.vals.first().requireSymbol(contents, quotedCheck)
}
data class NameMap(
// Key prefixed with type then colon before actual name
val names: Map<String, Int>,
// Null if not including names
val funcNames: Map<Int, String>?,
val localNames: Map<Int, Map<Int, String>>?
) {
val size get() = names.size
fun add(type: String, name: String, index: Int) = copy(
names = names + ("$type:$name" to index),
funcNames = funcNames?.let { if (type == "func") it + (index to name) else it }
)
fun get(type: String, name: String) = names["$type:$name"]
fun getAllNamesByIndex(type: String) = names.mapNotNull { (k, v) ->
k.takeIf { k.startsWith("$type:") }?.let { v to k.substring(type.length + 1) }
}.toMap()
}
companion object : SExprToAst()
}

View File

@ -4,16 +4,15 @@ import asmble.ast.SExpr
open class SExprToStr(val depthBeforeNewline: Int, val countBeforeNewlineAll: Int, val indent: String) {
@Suppress("UNCHECKED_CAST") // TODO: why?
fun fromSExpr(vararg exp: SExpr): String = appendAll(exp.asList(), StringBuilder()).trim().toString()
@Suppress("UNCHECKED_CAST") // TODO: why?
@Suppress("UNCHECKED_CAST")
fun <T : Appendable> append(exp: SExpr, sb: T = StringBuilder() as T, indentLevel: Int = 0) = when(exp) {
is SExpr.Symbol -> appendSymbol(exp, sb)
is SExpr.Multi -> appendMulti(exp, sb, indentLevel)
}
@Suppress("UNCHECKED_CAST") // TODO: why?
@Suppress("UNCHECKED_CAST")
fun <T : Appendable> appendSymbol(exp: SExpr.Symbol, sb: T = StringBuilder() as T): T {
val quoted = exp.quoted || exp.contents.requiresQuote
if (!quoted) sb.append(exp.contents) else {
@ -33,7 +32,7 @@ open class SExprToStr(val depthBeforeNewline: Int, val countBeforeNewlineAll: In
return sb
}
@Suppress("UNCHECKED_CAST") // TODO: why?
@Suppress("UNCHECKED_CAST")
fun <T : Appendable> appendMulti(exp: SExpr.Multi, sb: T = StringBuilder() as T, indentLevel: Int = 0): T {
sb.append('(')
appendAll(exp.vals, sb, indentLevel)
@ -41,7 +40,7 @@ open class SExprToStr(val depthBeforeNewline: Int, val countBeforeNewlineAll: In
return sb
}
@Suppress("UNCHECKED_CAST") // TODO: why?
@Suppress("UNCHECKED_CAST")
fun <T : Appendable> appendAll(exps: List<SExpr>, sb: T = StringBuilder() as T, indentLevel: Int = 0): T {
val newlineAll = exps.sumBy { it.count() } >= countBeforeNewlineAll
var wasLastNewline = false

View File

@ -1,6 +1,8 @@
package asmble.io
import asmble.ast.SExpr
import java.nio.ByteBuffer
import java.nio.charset.CharacterCodingException
open class StrToSExpr {
sealed class ParseResult {
@ -16,6 +18,13 @@ open class StrToSExpr {
data class Error(val pos: Pos, val msg: String) : ParseResult()
}
fun parseSingleMulti(str: CharSequence) = parse(str).let {
when (it) {
is ParseResult.Success -> (it.vals.singleOrNull() as? SExpr.Multi) ?: error("Not a single multi-expr")
is ParseResult.Error -> error("Failed parsing at ${it.pos.line}:${it.pos.char} - ${it.msg}")
}
}
fun parse(str: CharSequence): ParseResult {
val state = ParseState(str)
val ret = mutableListOf<SExpr>()
@ -53,9 +62,29 @@ open class StrToSExpr {
}
'"' -> {
offset++
// Check escapes
// We go over each char here checking escapes
var retStr = ""
// The WASM spec says we can treat chars normally unless they are hex escapes at which point they
// are raw bytes. Since we want to store everything as a string for later use, we need to keep track
// which set of raw bytes were invalid UTF-8 for UTF-8 validation later. Alternatively, we could
// just store in bytes and decode on use but this is easier. We keep a list of byte "runs" and at
// the end of each "run", we check whether they would make a valid UTF-8 string.
var hasNonUtf8ByteSeqs = false
val currByteSeq = mutableListOf<Byte>()
fun checkByteSeq() {
if (!hasNonUtf8ByteSeqs && currByteSeq.isNotEmpty()) {
try {
Charsets.UTF_8.newDecoder().decode(ByteBuffer.wrap(currByteSeq.toByteArray()))
} catch (_: CharacterCodingException) {
hasNonUtf8ByteSeqs = true
}
currByteSeq.clear()
}
}
while (err == null && !isEof && str[offset] != '"') {
var wasEscapedChar = false
if (str[offset] == '\\') {
offset++
if (isEof) err = "EOF when expected char to unescape" else {
@ -69,7 +98,10 @@ open class StrToSExpr {
// Try to parse hex if there is enough, otherwise just gripe
if (offset + 1 >= str.length) err = "Not enough to hex escape" else {
try {
retStr += str.substring(offset, offset + 2).toInt(16).toChar()
val int = str.substring(offset, offset + 2).toInt(16)
retStr += int.toChar()
currByteSeq.add(int.toByte())
wasEscapedChar = true
offset++
} catch (e: NumberFormatException) {
err = "Unknown escape: ${str.substring(offset, offset + 2)}: $e"
@ -83,10 +115,12 @@ open class StrToSExpr {
retStr += str[offset]
offset++
}
if (!wasEscapedChar) checkByteSeq()
}
checkByteSeq()
if (err == null && str[offset] != '"') err = "EOF when expected '\"'"
else if (err == null) offset++
val ret = SExpr.Symbol(retStr, true)
val ret = SExpr.Symbol(retStr, true, hasNonUtf8ByteSeqs)
exprOffsetMap.put(System.identityHashCode(ret), origOffset)
return ret
}

View File

@ -3,6 +3,7 @@ package asmble.run.jvm
import asmble.AsmErr
import java.lang.invoke.WrongMethodTypeException
import java.nio.BufferOverflowException
import java.nio.charset.MalformedInputException
open class ExceptionTranslator {
fun translate(ex: Throwable): List<String> = when (ex) {
@ -13,11 +14,12 @@ open class ExceptionTranslator {
is ArrayIndexOutOfBoundsException -> listOf("undefined element", "elements segment does not fit")
is AsmErr -> ex.asmErrStrings
is IndexOutOfBoundsException -> listOf("out of bounds memory access")
is NoSuchMethodException -> listOf("unknown import", "type mismatch")
is MalformedInputException -> listOf("invalid UTF-8 encoding")
is NullPointerException -> listOf("undefined element", "uninitialized element")
is StackOverflowError -> listOf("call stack exhausted")
is UnsupportedOperationException -> listOf("unreachable executed")
is WrongMethodTypeException -> listOf("indirect call signature mismatch")
is WrongMethodTypeException -> listOf("indirect call type mismatch")
is NumberFormatException -> listOf("i32 constant")
else -> emptyList()
}

View File

@ -0,0 +1,171 @@
package asmble.run.jvm
import asmble.annotation.WasmExport
import asmble.annotation.WasmExternalKind
import asmble.ast.Node
import asmble.compile.jvm.Mem
import asmble.compile.jvm.ref
import java.lang.invoke.MethodHandle
import java.lang.invoke.MethodHandles
import java.lang.invoke.MethodType
import java.lang.reflect.Constructor
import java.lang.reflect.Modifier
interface Module {
fun bindMethod(
ctx: ScriptContext,
wasmName: String,
wasmKind: WasmExternalKind,
javaName: String,
type: MethodType
): MethodHandle?
data class Composite(val modules: List<Module>) : Module {
override fun bindMethod(
ctx: ScriptContext,
wasmName: String,
wasmKind: WasmExternalKind,
javaName: String,
type: MethodType
) = modules.asSequence().mapNotNull { it.bindMethod(ctx, wasmName, wasmKind, javaName, type) }.singleOrNull()
}
interface Instance : Module {
val cls: Class<*>
// Guaranteed to be the same instance when there is no error
fun instance(ctx: ScriptContext): Any
override fun bindMethod(
ctx: ScriptContext,
wasmName: String,
wasmKind: WasmExternalKind,
javaName: String,
type: MethodType
) = cls.methods.filter {
// @WasmExport match or just javaName match
Modifier.isPublic(it.modifiers) &&
!Modifier.isStatic(it.modifiers) &&
it.getDeclaredAnnotation(WasmExport::class.java).let { ann ->
if (ann == null) it.name == javaName else ann.value == wasmName && ann.kind == wasmKind
}
}.mapNotNull {
MethodHandles.lookup().unreflect(it).bindTo(instance(ctx)).takeIf { it.type() == type }
}.singleOrNull()
}
data class Native(override val cls: Class<*>, val inst: Any) : Instance {
constructor(inst: Any) : this(inst::class.java, inst)
override fun instance(ctx: ScriptContext) = inst
}
class Compiled(
val mod: Node.Module,
override val cls: Class<*>,
val name: String?,
val mem: Mem
) : Instance {
private var inst: Any? = null
override fun instance(ctx: ScriptContext) =
synchronized(this) { inst ?: createInstance(ctx).also { inst = it } }
private fun createInstance(ctx: ScriptContext): Any {
// Find the constructor
var constructorParams = emptyList<Any>()
var constructor: Constructor<*>?
// If there is a memory import, we have to get the one with the mem class as the first
val memImport = mod.imports.find { it.kind is Node.Import.Kind.Memory }
val memLimit = if (memImport != null) {
constructor = cls.declaredConstructors.find { it.parameterTypes.firstOrNull()?.ref == mem.memType }
val memImportKind = memImport.kind as Node.Import.Kind.Memory
val memInst = ctx.resolveImportMemory(memImport, memImportKind.type, mem)
constructorParams += memInst
val (memLimit, memCap) = mem.limitAndCapacity(memInst)
if (memLimit < memImportKind.type.limits.initial * Mem.PAGE_SIZE)
throw RunErr.ImportMemoryLimitTooSmall(memImportKind.type.limits.initial * Mem.PAGE_SIZE, memLimit)
memImportKind.type.limits.maximum?.let {
if (memCap > it * Mem.PAGE_SIZE)
throw RunErr.ImportMemoryCapacityTooLarge(it * Mem.PAGE_SIZE, memCap)
}
memLimit
} else {
// Find the constructor with no max mem amount (i.e. not int and not memory)
constructor = cls.declaredConstructors.find {
val memClass = Class.forName(mem.memType.asm.className)
when (it.parameterTypes.firstOrNull()) {
Int::class.java, memClass -> false
else -> true
}
}
// If it is not there, find the one w/ the max mem amount
val maybeMem = mod.memories.firstOrNull()
if (constructor == null) {
val maxMem = Math.max(maybeMem?.limits?.initial ?: 0, ctx.defaultMaxMemPages)
constructor = cls.declaredConstructors.find { it.parameterTypes.firstOrNull() == Int::class.java }
constructorParams += maxMem * Mem.PAGE_SIZE
}
maybeMem?.limits?.initial?.let { it * Mem.PAGE_SIZE }
}
if (constructor == null) error("Unable to find suitable module constructor")
// Function imports
constructorParams += mod.imports.mapNotNull {
if (it.kind is Node.Import.Kind.Func) ctx.resolveImportFunc(it, mod.types[it.kind.typeIndex])
else null
}
// Global imports
val globalImports = mod.imports.flatMap {
if (it.kind is Node.Import.Kind.Global) ctx.resolveImportGlobals(it, it.kind.type)
else emptyList()
}
constructorParams += globalImports
// Table imports
val tableImport = mod.imports.find { it.kind is Node.Import.Kind.Table }
val tableSize = if (tableImport != null) {
val tableImportKind = tableImport.kind as Node.Import.Kind.Table
val table = ctx.resolveImportTable(tableImport, tableImportKind.type)
if (table.size < tableImportKind.type.limits.initial)
throw RunErr.ImportTableTooSmall(tableImportKind.type.limits.initial, table.size)
tableImportKind.type.limits.maximum?.let {
if (table.size > it) throw RunErr.ImportTableTooLarge(it, table.size)
}
constructorParams = constructorParams.plusElement(table)
table.size
} else mod.tables.firstOrNull()?.limits?.initial
// We need to validate that elems can fit in table and data can fit in mem
fun constIntExpr(insns: List<Node.Instr>): Int? = insns.singleOrNull()?.let {
when (it) {
is Node.Instr.I32Const -> it.value
is Node.Instr.GetGlobal ->
if (it.index < globalImports.size) {
// Imports we already have
if (globalImports[it.index].type().returnType() == Int::class.java) {
globalImports[it.index].invokeWithArguments() as Int
} else null
} else constIntExpr(mod.globals[it.index - globalImports.size].init)
else -> null
}
}
if (tableSize != null) mod.elems.forEach { elem ->
constIntExpr(elem.offset)?.let { offset ->
if (offset + elem.funcIndices.size > tableSize)
throw RunErr.InvalidElemIndex(offset, elem.funcIndices.size, tableSize)
}
}
if (memLimit != null) mod.data.forEach { data ->
constIntExpr(data.offset)?.let { offset ->
if (offset < 0 || offset + data.data.size > memLimit)
throw RunErr.InvalidDataIndex(offset, data.data.size, memLimit)
}
}
// Construct
ctx.debug { "Instantiating $cls using $constructor with params $constructorParams" }
return constructor.newInstance(*constructorParams.toTypedArray())
}
}
}

View File

@ -8,14 +8,14 @@ sealed class RunErr(message: String, cause: Throwable? = null) : RuntimeExceptio
val expected: Int,
val actual: Int
) : RunErr("Import memory limit $actual but expecting at least $expected") {
override val asmErrString get() = "actual size smaller than declared"
override val asmErrString get() = "incompatible import type"
}
class ImportMemoryCapacityTooLarge(
val expected: Int,
val actual: Int
) : RunErr("Import table capacity $actual but expecting no more than $expected") {
override val asmErrString get() = "maximum size larger than declared"
override val asmErrString get() = "incompatible import type"
}
class InvalidDataIndex(
@ -30,20 +30,37 @@ sealed class RunErr(message: String, cause: Throwable? = null) : RuntimeExceptio
val expected: Int,
val actual: Int
) : RunErr("Import table sized $actual but expecting at least $expected") {
override val asmErrString get() = "actual size smaller than declared"
override val asmErrString get() = "incompatible import type"
}
class ImportTableTooLarge(
val expected: Int,
val actual: Int
) : RunErr("Import table sized $actual but expecting no more than $expected") {
override val asmErrString get() = "maximum size larger than declared"
override val asmErrString get() = "incompatible import type"
}
class InvalidElemIndex(
val index: Int,
val offset: Int,
val elemSize: Int,
val tableSize: Int
) : RunErr("Trying to set elem at index $index but table size is only $tableSize") {
) : RunErr("Trying to set $elemSize elems at offset $offset but table size is only $tableSize") {
override val asmErrString get() = "elements segment does not fit"
}
class ImportNotFound(
val module: String,
val field: String
) : RunErr("Cannot find compatible import for $module::$field") {
override val asmErrString get() = "unknown import"
override val asmErrStrings get() = listOf(asmErrString, "incompatible import type")
}
class ImportGlobalInvalidMutability(
val module: String,
val field: String,
val expected: Boolean
) : RunErr("Expected imported global $module::$field to have mutability as ${!expected}") {
override val asmErrString get() = "incompatible import type"
}
}

View File

@ -1,11 +1,11 @@
package asmble.run.jvm
import asmble.annotation.WasmExternalKind
import asmble.ast.Node
import asmble.ast.Script
import asmble.compile.jvm.*
import asmble.io.AstToSExpr
import asmble.io.SExprToStr
import asmble.run.jvm.annotation.WasmName
import asmble.util.Logger
import asmble.util.toRawIntBits
import asmble.util.toRawLongBits
@ -16,25 +16,39 @@ import java.io.PrintWriter
import java.lang.invoke.MethodHandle
import java.lang.invoke.MethodHandles
import java.lang.invoke.MethodType
import java.lang.reflect.Constructor
import java.lang.reflect.InvocationTargetException
import java.util.*
/**
* Script context. Contains all the information needed to execute this script.
*
* @param packageName Package name for this script
* @param modules List of all modules to be load for this script
* @param registrations Registered modules, key - module name, value - module instance
* @param logger Logger for this script
* @param adjustContext Fn for tuning context (looks redundant)
* @param classLoader ClassLoader for loading all classes for this script
* @param exceptionTranslator Converts exceptions to error messages
* @param defaultMaxMemPages The maximum number of memory pages when a module doesn't say
* @param includeBinaryInCompiledClass Store binary wasm code to compiled class
* file as annotation [asmble.annotation.WasmModule]
*/
data class ScriptContext(
val packageName: String,
val modules: List<CompiledModule> = emptyList(),
val modules: List<Module.Compiled> = emptyList(),
val registrations: Map<String, Module> = emptyMap(),
val logger: Logger = Logger.Print(Logger.Level.OFF),
val adjustContext: (ClsContext) -> ClsContext = { it },
val classLoader: SimpleClassLoader =
ScriptContext.SimpleClassLoader(ScriptContext::class.java.classLoader, logger),
val exceptionTranslator: ExceptionTranslator = ExceptionTranslator,
val defaultMaxMemPages: Int = 1
val defaultMaxMemPages: Int = 1,
val includeBinaryInCompiledClass: Boolean = false
) : Logger by logger {
fun withHarnessRegistered(out: PrintWriter = PrintWriter(System.out, true)) =
copy(registrations = registrations + (
"spectest" to NativeModule(TestHarness::class.java, TestHarness(out))
))
withModuleRegistered("spectest", Module.Native(TestHarness(out)))
fun withModuleRegistered(name: String, mod: Module) = copy(registrations = registrations + (name to mod))
fun runCommand(cmd: Script.Cmd) = when (cmd) {
is Script.Cmd.Module ->
@ -106,7 +120,6 @@ data class ScriptContext(
}
fun assertReturnNan(ret: Script.Cmd.Assertion.ReturnNan) {
// TODO: validate canonical vs arithmetic
val (retType, retVal) = doAction(ret.action)
when (retType) {
Node.Type.Value.F32 ->
@ -177,7 +190,12 @@ data class ScriptContext(
val msgs = exceptionTranslator.translate(innerEx)
if (msgs.isEmpty())
throw ScriptAssertionError(a, "Expected failure '$expectedString' but got unknown err", cause = innerEx)
if (!msgs.any { it.contains(expectedString) })
var msgToFind = expectedString
// Special case for "uninitialized element" error match. This is because the error is expected to
// be "uninitialized number #" where # is the indirect call number. But it is at runtime where this fails
// so it is not worth it for us to store the index of failure. So we generalize it.
if (msgToFind.startsWith("uninitialized element")) msgToFind = "uninitialized element"
if (!msgs.any { it.contains(msgToFind) })
throw ScriptAssertionError(a, "Expected failure '$expectedString' in $msgs", cause = innerEx)
}
@ -247,172 +265,60 @@ data class ScriptContext(
fun withCompiledModule(mod: Node.Module, className: String, name: String?) =
copy(modules = modules + compileModule(mod, className, name))
fun compileModule(mod: Node.Module, className: String, name: String?): CompiledModule {
fun compileModule(mod: Node.Module, className: String, name: String?): Module.Compiled {
val ctx = ClsContext(
packageName = packageName,
className = className,
mod = mod,
logger = logger
logger = logger,
includeBinary = includeBinaryInCompiledClass
).let(adjustContext)
AstToAsm.fromModule(ctx)
return CompiledModule(mod, classLoader.fromBuiltContext(ctx), name, ctx.mem)
return Module.Compiled(mod, classLoader.fromBuiltContext(ctx), name, ctx.mem)
}
fun bindImport(import: Node.Import, getter: Boolean, methodType: MethodType): MethodHandle {
fun bindImport(import: Node.Import, getter: Boolean, methodType: MethodType) = bindImport(
import, if (getter) "get" + import.field.javaIdent.capitalize() else import.field.javaIdent, methodType)
fun bindImport(import: Node.Import, javaName: String, methodType: MethodType): MethodHandle {
// Find a method that matches our expectations
val module = registrations[import.module] ?: error("Unable to find module ${import.module}")
// TODO: do I want to introduce a complicated set of code that will find
// a method that can accept the given params including varargs, boxing, etc?
// I doubt it since it's only the JVM layer, WASM doesn't have parametric polymorphism
try {
val javaName = if (getter) "get" + import.field.javaIdent.capitalize() else import.field.javaIdent
return MethodHandles.lookup().bind(module.instance(this), javaName, methodType)
} catch (e: NoSuchMethodException) {
// Try any method w/ the proper annotation
module.cls.methods.forEach { method ->
if (method.getAnnotation(WasmName::class.java)?.value == import.field) {
val handle = MethodHandles.lookup().unreflect(method).bindTo(module.instance(this))
if (handle.type() == methodType) return handle
}
}
throw e
val module = registrations[import.module] ?: throw RunErr.ImportNotFound(import.module, import.field)
val kind = when (import.kind) {
is Node.Import.Kind.Func -> WasmExternalKind.FUNCTION
is Node.Import.Kind.Table -> WasmExternalKind.TABLE
is Node.Import.Kind.Memory -> WasmExternalKind.MEMORY
is Node.Import.Kind.Global -> WasmExternalKind.GLOBAL
}
return module.bindMethod(this, import.field, kind, javaName, methodType) ?:
throw RunErr.ImportNotFound(import.module, import.field)
}
fun resolveImportFunc(import: Node.Import, funcType: Node.Type.Func) =
bindImport(import, false,
MethodType.methodType(funcType.ret?.jclass ?: Void.TYPE, funcType.params.map { it.jclass }))
fun resolveImportGlobal(import: Node.Import, globalType: Node.Type.Global) =
bindImport(import, true, MethodType.methodType(globalType.contentType.jclass))
fun resolveImportGlobals(import: Node.Import, globalType: Node.Type.Global): List<MethodHandle> {
val getter = bindImport(import, true, MethodType.methodType(globalType.contentType.jclass))
// Whether the setter is present or not defines whether it is mutable
val setter = try {
bindImport(import, "set" + import.field.javaIdent.capitalize(),
MethodType.methodType(Void.TYPE, globalType.contentType.jclass))
} catch (e: RunErr.ImportNotFound) { null }
// Mutability must match
if (globalType.mutable == (setter == null))
throw RunErr.ImportGlobalInvalidMutability(import.module, import.field, globalType.mutable)
return if (setter == null) listOf(getter) else listOf(getter, setter)
}
fun resolveImportMemory(import: Node.Import, memoryType: Node.Type.Memory, mem: Mem) =
bindImport(import, true, MethodType.methodType(Class.forName(mem.memType.asm.className))).
invokeWithArguments()!!
@Suppress("UNCHECKED_CAST")
fun resolveImportTable(import: Node.Import, tableType: Node.Type.Table) =
bindImport(import, true, MethodType.methodType(Array<MethodHandle>::class.java)).
invokeWithArguments()!! as Array<MethodHandle>
interface Module {
val cls: Class<*>
// Guaranteed to be the same instance when there is no error
fun instance(ctx: ScriptContext): Any
}
class NativeModule(override val cls: Class<*>, val inst: Any) : Module {
override fun instance(ctx: ScriptContext) = inst
}
class CompiledModule(
val mod: Node.Module,
override val cls: Class<*>,
val name: String?,
val mem: Mem
) : Module {
private var inst: Any? = null
override fun instance(ctx: ScriptContext) =
synchronized(this) { inst ?: createInstance(ctx).also { inst = it } }
private fun createInstance(ctx: ScriptContext): Any {
// Find the constructor
var constructorParams = emptyList<Any>()
var constructor: Constructor<*>?
// If there is a memory import, we have to get the one with the mem class as the first
val memImport = mod.imports.find { it.kind is Node.Import.Kind.Memory }
val memLimit = if (memImport != null) {
constructor = cls.declaredConstructors.find { it.parameterTypes.firstOrNull()?.ref == mem.memType }
val memImportKind = memImport.kind as Node.Import.Kind.Memory
val memInst = ctx.resolveImportMemory(memImport, memImportKind.type, mem)
constructorParams += memInst
val (memLimit, memCap) = mem.limitAndCapacity(memInst)
if (memLimit < memImportKind.type.limits.initial * Mem.PAGE_SIZE)
throw RunErr.ImportMemoryLimitTooSmall(memImportKind.type.limits.initial * Mem.PAGE_SIZE, memLimit)
memImportKind.type.limits.maximum?.let {
if (memCap > it * Mem.PAGE_SIZE)
throw RunErr.ImportMemoryCapacityTooLarge(it * Mem.PAGE_SIZE, memCap)
}
memLimit
} else {
// Find the constructor with no max mem amount (i.e. not int and not memory)
constructor = cls.declaredConstructors.find {
val memClass = Class.forName(mem.memType.asm.className)
when (it.parameterTypes.firstOrNull()) {
Int::class.java, memClass -> false
else -> true
}
}
// If it is not there, find the one w/ the max mem amount
val maybeMem = mod.memories.firstOrNull()
if (constructor == null) {
val maxMem = Math.max(maybeMem?.limits?.initial ?: 0, ctx.defaultMaxMemPages)
constructor = cls.declaredConstructors.find { it.parameterTypes.firstOrNull() == Int::class.java }
constructorParams += maxMem * Mem.PAGE_SIZE
}
maybeMem?.limits?.initial?.let { it * Mem.PAGE_SIZE }
}
if (constructor == null) error("Unable to find suitable module constructor")
// Function imports
constructorParams += mod.imports.mapNotNull {
if (it.kind is Node.Import.Kind.Func) ctx.resolveImportFunc(it, mod.types[it.kind.typeIndex])
else null
}
// Global imports
val globalImports = mod.imports.mapNotNull {
if (it.kind is Node.Import.Kind.Global) ctx.resolveImportGlobal(it, it.kind.type)
else null
}
constructorParams += globalImports
// Table imports
val tableImport = mod.imports.find { it.kind is Node.Import.Kind.Table }
val tableSize = if (tableImport != null) {
val tableImportKind = tableImport.kind as Node.Import.Kind.Table
val table = ctx.resolveImportTable(tableImport, tableImportKind.type)
if (table.size < tableImportKind.type.limits.initial)
throw RunErr.ImportTableTooSmall(tableImportKind.type.limits.initial, table.size)
tableImportKind.type.limits.maximum?.let {
if (table.size > it) throw RunErr.ImportTableTooLarge(it, table.size)
}
constructorParams = constructorParams.plusElement(table)
table.size
} else mod.tables.firstOrNull()?.limits?.initial
// We need to validate that elems can fit in table and data can fit in mem
fun constIntExpr(insns: List<Node.Instr>): Int? = insns.singleOrNull()?.let {
when (it) {
is Node.Instr.I32Const -> it.value
is Node.Instr.GetGlobal ->
if (it.index < globalImports.size) {
// Imports we already have
if (globalImports[it.index].type().returnType() == Int::class.java) {
globalImports[it.index].invokeWithArguments() as Int
} else null
} else constIntExpr(mod.globals[it.index - globalImports.size].init)
else -> null
}
}
if (tableSize != null) mod.elems.forEach { elem ->
constIntExpr(elem.offset)?.let { offset ->
if (offset >= tableSize) throw RunErr.InvalidElemIndex(offset, tableSize)
}
}
if (memLimit != null) mod.data.forEach { data ->
constIntExpr(data.offset)?.let { offset ->
if (offset < 0 || offset + data.data.size > memLimit)
throw RunErr.InvalidDataIndex(offset, data.data.size, memLimit)
}
}
// Construct
ctx.debug { "Instantiating $cls using $constructor with params $constructorParams" }
return constructor.newInstance(*constructorParams.toTypedArray())
}
}
open class SimpleClassLoader(parent: ClassLoader, logger: Logger) : ClassLoader(parent), Logger by logger {
fun fromBuiltContext(ctx: ClsContext): Class<*> {
trace { "Computing frames for ASM class:\n" + ctx.cls.toAsmString() }

View File

@ -1,7 +1,8 @@
package asmble.run.jvm
import asmble.annotation.WasmExport
import asmble.annotation.WasmExternalKind
import asmble.compile.jvm.Mem
import asmble.run.jvm.annotation.WasmName
import java.io.PrintWriter
import java.lang.invoke.MethodHandle
import java.nio.ByteBuffer
@ -11,10 +12,10 @@ open class TestHarness(val out: PrintWriter) {
// WASM is evil, not me:
// https://github.com/WebAssembly/spec/blob/6a01dab6d29b7c2b5dfd3bb3879bbd6ab76fd5dc/interpreter/host/import/spectest.ml#L12
@get:WasmName("global") val globalInt = 666
@get:WasmName("global") val globalLong = 666L
@get:WasmName("global") val globalFloat = 666.6f
@get:WasmName("global") val globalDouble = 666.6
val global_i32 = 666
val global_i64 = 666L
val global_f32 = 666.6f
val global_f64 = 666.6
val table = arrayOfNulls<MethodHandle>(10)
val memory = ByteBuffer.
allocateDirect(2 * Mem.PAGE_SIZE).
@ -26,12 +27,12 @@ open class TestHarness(val out: PrintWriter) {
// mh-lookup-bind. It does not support varargs, boxing, or any of
// that currently.
fun print() { }
fun print(arg0: Int) { out.println("$arg0 : i32") }
fun print(arg0: Long) { out.println("$arg0 : i64") }
fun print(arg0: Float) { out.printf("%#.0f : f32", arg0).println() }
fun print(arg0: Double) { out.printf("%#.0f : f64", arg0).println() }
fun print(arg0: Int, arg1: Float) { print(arg0); print(arg1) }
fun print(arg0: Double, arg1: Double) { print(arg0); print(arg1) }
fun print_i32(arg0: Int) { out.println("$arg0 : i32") }
fun print_i64(arg0: Long) { out.println("$arg0 : i64") }
fun print_f32(arg0: Float) { out.printf("%#.0f : f32", arg0).println() }
fun print_f64(arg0: Double) { out.printf("%#.0f : f64", arg0).println() }
fun print_i32_f32(arg0: Int, arg1: Float) { print(arg0); print(arg1) }
fun print_f64_f64(arg0: Double, arg1: Double) { print(arg0); print(arg1) }
companion object : TestHarness(PrintWriter(System.out, true))
}

View File

@ -0,0 +1,5 @@
package asmble.util
import java.nio.ByteBuffer
fun ByteBuffer.get(index: Int, bytes: ByteArray) = this.duplicate().also { it.position(index) }.get(bytes)

View File

@ -1,8 +1,13 @@
package asmble.util
import java.math.BigDecimal
import java.math.BigInteger
internal const val INT_MASK = 0xffffffffL
internal val MAX_UINT32 = BigInteger("ffffffff", 16)
internal val MIN_INT32 = BigInteger.valueOf(Int.MIN_VALUE.toLong())
internal val MAX_UINT64 = BigInteger("ffffffffffffffff", 16)
internal val MIN_INT64 = BigInteger.valueOf(Long.MIN_VALUE)
fun Byte.toUnsignedShort() = (this.toInt() and 0xff).toShort()

View File

@ -0,0 +1,26 @@
package asmble
import asmble.ast.SExpr
import asmble.ast.Script
import asmble.io.SExprToAst
import asmble.io.StrToSExpr
open class BaseTestUnit(val name: String, val wast: String, val expectedOutput: String?) {
override fun toString() = "Test unit: $name"
open val packageName = "asmble.temp." + name.replace('/', '.')
open val shouldFail get() = false
open val skipRunReason: String? get() = null
open val defaultMaxMemPages get() = 1
open val parseResult: StrToSExpr.ParseResult.Success by lazy {
StrToSExpr.parse(wast).let {
when (it) {
is StrToSExpr.ParseResult.Error -> throw Exception("$name[${it.pos}] Parse fail: ${it.msg}")
is StrToSExpr.ParseResult.Success -> it
}
}
}
open val ast: List<SExpr> get() = parseResult.vals
open val script: Script by lazy { SExprToAst.toScript(SExpr.Multi(ast)) }
open fun warningInsteadOfErrReason(t: Throwable): String? = null
}

View File

@ -1,34 +1,27 @@
package asmble
import asmble.ast.Node
import asmble.ast.SExpr
import asmble.ast.Script
import asmble.io.SExprToAst
import asmble.io.StrToSExpr
import asmble.run.jvm.ScriptAssertionError
import java.nio.file.FileSystems
import java.nio.file.Files
import java.nio.file.Paths
import java.util.stream.Collectors
class SpecTestUnit(val name: String, val wast: String, val expectedOutput: String?) {
class SpecTestUnit(name: String, wast: String, expectedOutput: String?) : BaseTestUnit(name, wast, expectedOutput) {
override fun toString() = "Spec unit: $name"
override val shouldFail get() = name.endsWith(".fail")
val shouldFail get() = name.endsWith(".fail")
val skipRunReason: String? get() = null
val defaultMaxMemPages get() = when (name) {
"nop"-> 20
"resizing" -> 830
override val defaultMaxMemPages get() = when (name) {
"nop" -> 20
"memory_grow" -> 830
"imports" -> 5
else -> 1
}
fun warningInsteadOfErrReason(t: Throwable) = when (name) {
override fun warningInsteadOfErrReason(t: Throwable) = when (name) {
// NaN bit patterns can be off
"float_literals", "float_exprs" ->
"float_literals", "float_exprs", "float_misc" ->
if (isNanMismatch(t)) "NaN JVM bit patterns can be off" else null
// We don't hold table capacity right now
// TODO: Figure out how we want to store/retrieve table capacity. Right now
@ -39,7 +32,7 @@ class SpecTestUnit(val name: String, val wast: String, val expectedOutput: Strin
// capacity since you lose speed.
"imports" -> {
val isTableMaxErr = t is ScriptAssertionError && (t.assertion as? Script.Cmd.Assertion.Unlinkable).let {
it != null && it.failure == "maximum size larger than declared" &&
it != null && it.failure == "incompatible import type" &&
it.module.imports.singleOrNull()?.kind is Node.Import.Kind.Table
}
if (isTableMaxErr) "Table max capacities are not validated" else null
@ -61,19 +54,6 @@ class SpecTestUnit(val name: String, val wast: String, val expectedOutput: Strin
else -> false
}
val parseResult: StrToSExpr.ParseResult.Success by lazy {
StrToSExpr.parse(wast).let {
when (it) {
is StrToSExpr.ParseResult.Error -> throw Exception("$name[${it.pos}] Parse fail: ${it.msg}")
is StrToSExpr.ParseResult.Success -> it
}
}
}
val ast: List<SExpr> get() = parseResult.vals
val script: Script by lazy { SExprToAst.toScript(SExpr.Multi(ast)) }
companion object {
val unitsPath = "/spec/test/core"

View File

@ -0,0 +1,9 @@
package asmble
import asmble.util.Logger
abstract class TestBase : Logger by TestBase.logger {
companion object {
val logger = Logger.Print(Logger.Level.INFO)
}
}

View File

@ -0,0 +1,33 @@
package asmble.ast
import asmble.SpecTestUnit
import asmble.TestBase
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
@RunWith(Parameterized::class)
class StackTest(val unit: SpecTestUnit) : TestBase() {
@Test
fun testStack() {
// If it's not a module expecting an error, we'll try to walk the stack on each function
unit.script.commands.mapNotNull { it as? Script.Cmd.Module }.forEach { mod ->
mod.module.funcs.filter { it.instructions.isNotEmpty() }.forEach { func ->
debug { "Func: ${func.type}" }
var indexCount = 0
Stack.walkStrict(mod.module, func) { stack, insn ->
debug { " After $insn (next: ${func.instructions.getOrNull(++indexCount)}, " +
"unreach depth: ${stack.unreachableUntilNextEndCount})" }
debug { " " + stack.current }
}
}
}
}
companion object {
// Only tests that shouldn't fail
@JvmStatic @Parameterized.Parameters(name = "{0}")
fun data() = SpecTestUnit.allUnits.filterNot { it.shouldFail }//.filter { it.name == "loop" }
}
}

View File

@ -0,0 +1,89 @@
package asmble.ast.opt
import asmble.TestBase
import asmble.ast.Node
import asmble.compile.jvm.AstToAsm
import asmble.compile.jvm.ClsContext
import asmble.run.jvm.ScriptContext
import org.junit.Test
import java.nio.ByteBuffer
import java.util.*
import kotlin.test.assertEquals
class SplitLargeFuncTest : TestBase() {
@Test
fun testSplitLargeFunc() {
// We're going to make a large function that does some addition and then stores in mem
val ctx = ClsContext(
packageName = "test",
className = "Temp" + UUID.randomUUID().toString().replace("-", ""),
logger = logger,
mod = Node.Module(
memories = listOf(Node.Type.Memory(Node.ResizableLimits(initial = 2, maximum = 2))),
funcs = listOf(Node.Func(
type = Node.Type.Func(params = emptyList(), ret = null),
locals = emptyList(),
instructions = (0 until 501).flatMap {
listOf<Node.Instr>(
Node.Instr.I32Const(it * 4),
// Let's to i * (i = 1)
Node.Instr.I32Const(it),
Node.Instr.I32Const(it - 1),
Node.Instr.I32Mul,
Node.Instr.I32Store(0, 0)
)
}
)),
names = Node.NameSection(
moduleName = null,
funcNames = mapOf(0 to "someFunc"),
localNames = emptyMap()
),
exports = listOf(
Node.Export("memory", Node.ExternalKind.MEMORY, 0),
Node.Export("someFunc", Node.ExternalKind.FUNCTION, 0)
)
)
)
// Compile it
AstToAsm.fromModule(ctx)
val cls = ScriptContext.SimpleClassLoader(javaClass.classLoader, logger).fromBuiltContext(ctx)
val inst = cls.newInstance()
// Run someFunc
cls.getMethod("someFunc").invoke(inst)
// Get the memory out
val mem = cls.getMethod("getMemory").invoke(inst) as ByteBuffer
// Read out the mem values
(0 until 501).forEach { assertEquals(it * (it - 1), mem.getInt(it * 4)) }
// Now split it
val (splitMod, insnsSaved) = SplitLargeFunc.apply(ctx.mod, 0) ?: error("Nothing could be split")
// Count insns and confirm it is as expected
val origInsnCount = ctx.mod.funcs.sumBy { it.instructions.size }
val newInsnCount = splitMod.funcs.sumBy { it.instructions.size }
assertEquals(origInsnCount - newInsnCount, insnsSaved)
// Compile it
val splitCtx = ClsContext(
packageName = "test",
className = "Temp" + UUID.randomUUID().toString().replace("-", ""),
logger = logger,
mod = splitMod
)
AstToAsm.fromModule(splitCtx)
val splitCls = ScriptContext.SimpleClassLoader(javaClass.classLoader, logger).fromBuiltContext(splitCtx)
val splitInst = splitCls.newInstance()
// Run someFunc
splitCls.getMethod("someFunc").invoke(splitInst)
// Get the memory out and compare it
val splitMem = splitCls.getMethod("getMemory").invoke(splitInst) as ByteBuffer
assertEquals(mem, splitMem)
// Dump some info
logger.debug {
val orig = ctx.mod.funcs.first()
val (new, split) = splitMod.funcs.let { it.first() to it.last() }
"Split complete, from single func with ${orig.instructions.size} insns to func " +
"with ${new.instructions.size} insns + delegated func " +
"with ${split.instructions.size} insns and ${split.type.params.size} params"
}
}
}

View File

@ -0,0 +1,45 @@
package asmble.compile.jvm
import asmble.TestBase
import asmble.ast.Node
import asmble.run.jvm.ScriptContext
import asmble.util.get
import org.junit.Test
import java.nio.ByteBuffer
import java.util.*
import kotlin.test.assertEquals
class LargeDataTest : TestBase() {
@Test
fun testLargeData() {
// This previously failed because string constants can't be longer than 65536 chars.
// We create a byte array across the whole gambit of bytes to test UTF8 encoding.
val bytesExpected = ByteArray(70000) { ((it % 255) - Byte.MIN_VALUE).toByte() }
val mod = Node.Module(
memories = listOf(Node.Type.Memory(
limits = Node.ResizableLimits(initial = 2, maximum = 2)
)),
data = listOf(Node.Data(
index = 0,
offset = listOf(Node.Instr.I32Const(0)),
data = bytesExpected
))
)
val ctx = ClsContext(
packageName = "test",
className = "Temp" + UUID.randomUUID().toString().replace("-", ""),
mod = mod,
logger = logger
)
AstToAsm.fromModule(ctx)
val cls = ScriptContext.SimpleClassLoader(javaClass.classLoader, logger).fromBuiltContext(ctx)
// Instantiate it, get the memory out, and check it
val field = cls.getDeclaredField("memory").apply { isAccessible = true }
val buf = field[cls.newInstance()] as ByteBuffer
// Grab all + 1 and check values
val bytesActual = ByteArray(70001).also { buf.get(0, it) }
bytesActual.forEachIndexed { index, byte ->
assertEquals(if (index == 70000) 0.toByte() else bytesExpected[index], byte)
}
}
}

View File

@ -0,0 +1,38 @@
package asmble.compile.jvm
import asmble.TestBase
import asmble.io.SExprToAst
import asmble.io.StrToSExpr
import asmble.run.jvm.ScriptContext
import org.junit.Test
import java.util.*
class NamesTest : TestBase() {
@Test
fun testNames() {
// Compile and make sure the names are set right
val (_, mod) = SExprToAst.toModule(StrToSExpr.parseSingleMulti("""
(module ${'$'}mod_name
(import "foo" "bar" (func ${'$'}import_func (param i32)))
(type ${'$'}some_sig (func (param ${'$'}type_param i32)))
(func ${'$'}some_func
(type ${'$'}some_sig)
(param ${'$'}func_param i32)
(local ${'$'}func_local0 i32)
(local ${'$'}func_local1 f64)
)
)
""".trimIndent()))
val ctx = ClsContext(
packageName = "test",
className = "Temp" + UUID.randomUUID().toString().replace("-", ""),
mod = mod,
logger = logger
)
AstToAsm.fromModule(ctx)
val cls = ScriptContext.SimpleClassLoader(javaClass.classLoader, logger).fromBuiltContext(ctx)
// Make sure the import field and the func are present named
cls.getDeclaredField("import_func")
cls.getDeclaredMethod("some_func", Integer.TYPE)
}
}

View File

@ -2,12 +2,9 @@ package asmble.io
import org.junit.Assert.assertArrayEquals
import org.junit.Test
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.io.PipedInputStream
import java.io.PipedOutputStream
import java.math.BigInteger
import java.nio.ByteBuffer
import kotlin.test.assertEquals
import kotlin.test.assertFails
import kotlin.test.assertFalse

View File

@ -1,6 +1,7 @@
package asmble.io
import asmble.SpecTestUnit
import asmble.TestBase
import asmble.ast.Node
import asmble.ast.Script
import asmble.util.Logger
@ -13,12 +14,10 @@ import java.io.ByteArrayOutputStream
import kotlin.test.assertEquals
@RunWith(Parameterized::class)
class IoTest(val unit: SpecTestUnit) : Logger by Logger.Print(Logger.Level.INFO) {
class IoTest(val unit: SpecTestUnit) : TestBase() {
@Test
fun testIo() {
// Ignore things that are supposed to fail
if (unit.shouldFail) return
// Go from the AST to binary then back to AST then back to binary and confirm values are as expected
val ast1 = unit.script.commands.mapNotNull { (it as? Script.Cmd.Module)?.module?.also {
trace { "AST from script:\n" + SExprToStr.fromSExpr(AstToSExpr.fromModule(it)) }
@ -46,7 +45,8 @@ class IoTest(val unit: SpecTestUnit) : Logger by Logger.Print(Logger.Level.INFO)
}
companion object {
// Only tests that shouldn't fail
@JvmStatic @Parameterized.Parameters(name = "{0}")
fun data() = SpecTestUnit.allUnits
fun data() = SpecTestUnit.allUnits.filterNot { it.shouldFail }
}
}

View File

@ -0,0 +1,47 @@
package asmble.io
import asmble.ast.Node
import org.junit.Test
import kotlin.test.assertEquals
class NamesTest {
@Test
fun testNames() {
// First, make sure it can parse from sexpr
val (_, mod1) = SExprToAst.toModule(StrToSExpr.parseSingleMulti("""
(module ${'$'}mod_name
(import "foo" "bar" (func ${'$'}import_func (param i32)))
(type ${'$'}some_sig (func (param ${'$'}type_param i32)))
(func ${'$'}some_func
(type ${'$'}some_sig)
(param ${'$'}func_param i32)
(local ${'$'}func_local0 i32)
(local ${'$'}func_local1 f64)
)
)
""".trimIndent()))
val expected = Node.NameSection(
moduleName = "mod_name",
funcNames = mapOf(
0 to "import_func",
1 to "some_func"
),
localNames = mapOf(
1 to mapOf(
0 to "func_param",
1 to "func_local0",
2 to "func_local1"
)
)
)
assertEquals(expected, mod1.names)
// Now back to binary and then back and make sure it's still there
val bytes = AstToBinary.fromModule(mod1)
val mod2 = BinaryToAst.toModule(bytes)
assertEquals(expected, mod2.names)
// Now back to sexpr and then back to make sure the sexpr writer works
val sexpr = AstToSExpr.fromModule(mod2)
val (_, mod3) = SExprToAst.toModule(sexpr)
assertEquals(expected, mod3.names)
}
}

View File

@ -0,0 +1,13 @@
package asmble.run.jvm
import asmble.SpecTestUnit
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
@RunWith(Parameterized::class)
class RunTest(unit: SpecTestUnit) : TestRunner<SpecTestUnit>(unit) {
companion object {
@JvmStatic @Parameterized.Parameters(name = "{0}")
fun data() = SpecTestUnit.allUnits
}
}

View File

@ -0,0 +1,73 @@
package asmble.run.jvm
import asmble.BaseTestUnit
import asmble.TestBase
import asmble.annotation.WasmModule
import asmble.io.AstToBinary
import asmble.io.AstToSExpr
import asmble.io.ByteWriter
import asmble.io.SExprToStr
import org.junit.Assume
import org.junit.Test
import java.io.ByteArrayOutputStream
import java.io.OutputStreamWriter
import java.io.PrintWriter
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
abstract class TestRunner<out T : BaseTestUnit>(val unit: T) : TestBase() {
@Test
fun test() {
unit.skipRunReason?.let { Assume.assumeTrue("Skipping ${unit.name}, reason: $it", false) }
val ex = try { run(); null } catch (e: Throwable) { e }
if (unit.shouldFail) {
assertNotNull(ex, "Expected failure, but succeeded")
debug { "Got expected failure: $ex" }
} else if (ex != null) throw ex
}
private fun run() {
debug { "AST SExpr: " + unit.ast }
debug { "AST Str: " + SExprToStr.fromSExpr(*unit.ast.toTypedArray()) }
debug { "AST: " + unit.script }
debug { "AST Str: " + SExprToStr.fromSExpr(*AstToSExpr.fromScript(unit.script).toTypedArray()) }
val out = ByteArrayOutputStream()
var scriptContext = ScriptContext(
packageName = unit.packageName,
logger = this,
adjustContext = { it.copy(eagerFailLargeMemOffset = false) },
defaultMaxMemPages = unit.defaultMaxMemPages,
// Include the binary data so we can check it later
includeBinaryInCompiledClass = true
).withHarnessRegistered(PrintWriter(OutputStreamWriter(out, Charsets.UTF_8), true))
// This will fail assertions as necessary
scriptContext = unit.script.commands.fold(scriptContext) { scriptContext, cmd ->
try {
scriptContext.runCommand(cmd)
} catch (t: Throwable) {
val warningReason = unit.warningInsteadOfErrReason(t) ?: throw t
warn { "Unexpected error on ${unit.name}, but is a warning. Reason: $warningReason. Orig err: $t" }
scriptContext
}
}
// Check the output
unit.expectedOutput?.let {
// Sadly, sometimes the expected output is trimmed in Emscripten tests
assertEquals(it.trimEnd(), out.toByteArray().toString(Charsets.UTF_8).trimEnd())
}
// Also check the annotations
scriptContext.modules.forEach { mod ->
val expectedBinaryString = ByteArrayOutputStream().also {
ByteWriter.OutputStream(it).also { AstToBinary.fromModule(it, mod.mod) }
}.toByteArray().toString(Charsets.ISO_8859_1)
val actualBinaryString =
mod.cls.getDeclaredAnnotation(WasmModule::class.java)?.binary ?: error("No annotation")
assertEquals(expectedBinaryString, actualBinaryString)
}
}
}

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,49 @@
(module
(memory 1)
(global $foo (mut i32) (i32.const 20))
(global $bar (mut f32) (f32.const 0))
;; This was breaking because stack diff was wrong for get_global and set_global
(func (export "testGlobals") (param $p i32) (result i32)
(local i32)
(get_global $foo)
(set_local 1)
(get_global $foo)
(get_local $p)
(i32.add)
(set_global $foo)
(get_global $foo)
(i32.const 15)
(i32.add)
(i32.const -16)
(i32.and)
(set_global $foo)
(get_global $foo)
)
;; Sqrt had bad stack diff
(func (export "testSqrt") (param $p f32) (result f32)
(set_global $bar (f32.sqrt (get_local $p)))
(get_global $bar)
)
;; Conditionals w/ different load counts had bad stack diff
(func (export "testConditional") (param $p i32) (result i32)
(get_local $p)
(if (result i32) (get_local $p)
(then (i32.load (get_local $p)))
(else
(i32.add
(i32.load (get_local $p))
(i32.load (get_local $p))
)
)
)
(i32.store)
(i32.load (get_local $p))
)
)
(assert_return (invoke "testGlobals" (i32.const 7)) (i32.const 32))
(assert_return (invoke "testSqrt" (f32.const 144)) (f32.const 12))

11
examples/README.md Normal file
View File

@ -0,0 +1,11 @@
## Examples
Below are some examples.
### Rust
Compile Rust to WASM and then to the JVM. In order of complexity:
* [rust-simple](rust-simple)
* [rust-string](rust-string)
* [rust-regex](rust-regex)

View File

@ -0,0 +1,10 @@
package main
import (
"fmt"
"os"
)
func main() {
fmt.Printf("Args: %v", os.Args)
}

View File

@ -0,0 +1,2 @@
[build]
target = "wasm32-unknown-unknown"

View File

@ -0,0 +1,9 @@
[package]
name = "rust_regex"
version = "0.1.0"
[lib]
crate-type = ["cdylib"]
[dependencies]
regex = "0.2"

View File

@ -0,0 +1,140 @@
### Example: Rust Regex
This shows an example of using the Rust regex library on the JVM compiled via WASM. This builds on
the [rust-simple](../rust-simple) and [rust-string](../rust-string) examples. See the former for build prereqs. There is
also a simple benchmark checking the performance compared to the built-in Java regex engine.
#### Main
In this version, we include the `regex` crate. The main loads a ~15MB text file Project Gutenberg collection of Mark
Twain works (taken from [this blog post](https://rust-leipzig.github.io/regex/2017/03/28/comparison-of-regex-engines/)
that does Rust regex performance benchmarks). Both the Java and Rust regex engines are abstracted into a common
interface. When run, it checks how many times the word "Twain" appears via both regex engines.
To run it yourself, run the following from the root `asmble` dir:
gradlew --no-daemon :examples:rust-regex:run
In release mode, the generated class is 903KB w/ ~575 methods. The output:
'Twain' count in Java: 811
'Twain' count in Rust: 811
#### Tests
I wanted to compare the Java regex engine with the Rust regex engine. Before running benchmarks, I wrote a
[unit test](src/test/java/asmble/examples/rustregex/RegexTest.java) to test parity. I used the examples from the
aforementioned [blog post](https://rust-leipzig.github.io/regex/2017/03/28/comparison-of-regex-engines/) to test with.
The test simply confirms the Java regex library and the Rust regex library produce the same match counts across the
Mark Twain corpus. To run the test, execute:
gradlew --no-daemon :examples:rust-regex:test
Here is my output of the test part:
asmble.examples.rustregex.RegexTest > checkJavaVersusRust[pattern: Twain] PASSED
asmble.examples.rustregex.RegexTest > checkJavaVersusRust[pattern: (?i)Twain] PASSED
asmble.examples.rustregex.RegexTest > checkJavaVersusRust[pattern: [a-z]shing] PASSED
asmble.examples.rustregex.RegexTest > checkJavaVersusRust[pattern: Huck[a-zA-Z]+|Saw[a-zA-Z]+] PASSED
asmble.examples.rustregex.RegexTest > checkJavaVersusRust[pattern: \b\w+nn\b] PASSED
asmble.examples.rustregex.RegexTest > checkJavaVersusRust[pattern: [a-q][^u-z]{13}x] SKIPPED
asmble.examples.rustregex.RegexTest > checkJavaVersusRust[pattern: Tom|Sawyer|Huckleberry|Finn] PASSED
asmble.examples.rustregex.RegexTest > checkJavaVersusRust[pattern: (?i)Tom|Sawyer|Huckleberry|Finn] PASSED
asmble.examples.rustregex.RegexTest > checkJavaVersusRust[pattern: .{0,2}(Tom|Sawyer|Huckleberry|Finn)] PASSED
asmble.examples.rustregex.RegexTest > checkJavaVersusRust[pattern: .{2,4}(Tom|Sawyer|Huckleberry|Finn)] PASSED
asmble.examples.rustregex.RegexTest > checkJavaVersusRust[pattern: Tom.{10,25}river|river.{10,25}Tom] PASSED
asmble.examples.rustregex.RegexTest > checkJavaVersusRust[pattern: [a-zA-Z]+ing] PASSED
asmble.examples.rustregex.RegexTest > checkJavaVersusRust[pattern: \s[a-zA-Z]{0,12}ing\s] PASSED
asmble.examples.rustregex.RegexTest > checkJavaVersusRust[pattern: ([A-Za-z]awyer|[A-Za-z]inn)\s] PASSED
asmble.examples.rustregex.RegexTest > checkJavaVersusRust[pattern: ["'][^"']{0,30}[?!\.]["']] PASSED
asmble.examples.rustregex.RegexTest > checkJavaVersusRust[pattern: ?|?] PASSED
asmble.examples.rustregex.RegexTest > checkJavaVersusRust[pattern: \p{Sm}] PASSED
As mentioned in the blog post, `[a-q][^u-z]{13}x` is a very slow pattern for Rust, so I skipped it (but it does produce
the same count if you're willing to wait a couple of minutes). Also, `?|?` is actually `∞|✓`, it's just not printable
unicode in the text output I used.
#### Benchmarks
With the accuracy confirmed, now was time to benchmark the two engines. I wrote a
[JMH benchmark](src/jmh/java/asmble/examples/rustregex/RegexBenchmark.java) to test the same patterns as the unit test
checks. It precompiles the patterns and preloads the target string on the Rust side before checking simple match count.
As with any benchmarks, this is just my empirical data and everyone else's will be different. To run the benchmark,
execute (it takes a while to run):
gradlew --no-daemon :examples:rust-regex:jmh
Here are my results (reordered and with added linebreaks for readability, higher score is better):
Benchmark (patternString) Mode Cnt Score Error Units
RegexBenchmark.javaRegexCheck Twain thrpt 15 29.756 ± 1.169 ops/s
RegexBenchmark.rustRegexCheck Twain thrpt 15 55.012 ± 0.677 ops/s
RegexBenchmark.javaRegexCheck (?i)Twain thrpt 15 6.181 ± 0.560 ops/s
RegexBenchmark.rustRegexCheck (?i)Twain thrpt 15 1.333 ± 0.029 ops/s
RegexBenchmark.javaRegexCheck [a-z]shing thrpt 15 6.138 ± 0.937 ops/s
RegexBenchmark.rustRegexCheck [a-z]shing thrpt 15 12.352 ± 0.103 ops/s
RegexBenchmark.javaRegexCheck Huck[a-zA-Z]+|Saw[a-zA-Z]+ thrpt 15 4.774 ± 0.330 ops/s
RegexBenchmark.rustRegexCheck Huck[a-zA-Z]+|Saw[a-zA-Z]+ thrpt 15 56.079 ± 0.487 ops/s
RegexBenchmark.javaRegexCheck \b\w+nn\b thrpt 15 2.703 ± 0.086 ops/s
RegexBenchmark.rustRegexCheck \b\w+nn\b thrpt 15 0.131 ± 0.001 ops/s
RegexBenchmark.javaRegexCheck Tom|Sawyer|Huckleberry|Finn thrpt 15 2.633 ± 0.033 ops/s
RegexBenchmark.rustRegexCheck Tom|Sawyer|Huckleberry|Finn thrpt 15 14.388 ± 0.138 ops/s
RegexBenchmark.javaRegexCheck (?i)Tom|Sawyer|Huckleberry|Finn thrpt 15 3.178 ± 0.045 ops/s
RegexBenchmark.rustRegexCheck (?i)Tom|Sawyer|Huckleberry|Finn thrpt 15 8.882 ± 0.110 ops/s
RegexBenchmark.javaRegexCheck .{0,2}(Tom|Sawyer|Huckleberry|Finn) thrpt 15 1.191 ± 0.010 ops/s
RegexBenchmark.rustRegexCheck .{0,2}(Tom|Sawyer|Huckleberry|Finn) thrpt 15 0.572 ± 0.012 ops/s
RegexBenchmark.javaRegexCheck .{2,4}(Tom|Sawyer|Huckleberry|Finn) thrpt 15 1.017 ± 0.024 ops/s
RegexBenchmark.rustRegexCheck .{2,4}(Tom|Sawyer|Huckleberry|Finn) thrpt 15 0.584 ± 0.008 ops/s
RegexBenchmark.javaRegexCheck Tom.{10,25}river|river.{10,25}Tom thrpt 15 5.326 ± 0.050 ops/s
RegexBenchmark.rustRegexCheck Tom.{10,25}river|river.{10,25}Tom thrpt 15 15.705 ± 0.247 ops/s
RegexBenchmark.javaRegexCheck [a-zA-Z]+ing thrpt 15 1.768 ± 0.057 ops/s
RegexBenchmark.rustRegexCheck [a-zA-Z]+ing thrpt 15 1.001 ± 0.012 ops/s
RegexBenchmark.javaRegexCheck \s[a-zA-Z]{0,12}ing\s thrpt 15 4.020 ± 0.111 ops/s
RegexBenchmark.rustRegexCheck \s[a-zA-Z]{0,12}ing\s thrpt 15 0.416 ± 0.004 ops/s
RegexBenchmark.javaRegexCheck ([A-Za-z]awyer|[A-Za-z]inn)\s thrpt 15 2.441 ± 0.024 ops/s
RegexBenchmark.rustRegexCheck ([A-Za-z]awyer|[A-Za-z]inn)\s thrpt 15 0.591 ± 0.004 ops/s
RegexBenchmark.javaRegexCheck ["'][^"']{0,30}[?!\.]["'] thrpt 15 20.466 ± 0.309 ops/s
RegexBenchmark.rustRegexCheck ["'][^"']{0,30}[?!\.]["'] thrpt 15 2.459 ± 0.024 ops/s
RegexBenchmark.javaRegexCheck ?|? thrpt 15 15.856 ± 0.158 ops/s
RegexBenchmark.rustRegexCheck ?|? thrpt 15 14.657 ± 0.177 ops/s
RegexBenchmark.javaRegexCheck \p{Sm} thrpt 15 22.156 ± 0.406 ops/s
RegexBenchmark.rustRegexCheck \p{Sm} thrpt 15 0.592 ± 0.005 ops/s
To keep from making this a big long post like most benchmark posts tend to be, here is a bulleted list of notes:
* I ran this on a Win 10 box, 1.8GHz i7-8550U HP laptop. I used latest Zulu JDK 8. For JMH, I set it at 3 forks, 5
warmup iterations, and 5 measurement iterations (that's why `cnt` above is 15 = 5 measurements * 3 forks). It took a
bit over 25 minutes to complete.
* All of the tests had the Java and Rust patterns precompiled. In Rust's case, I also placed the UTF-8 string on the
accessible heap before the benchmark started to be fair.
* Like the unit test, I excluded `[a-q][^u-z]{13}x` because Rust is really slow at it (Java wins by a mile here). Also
like the unit test, `?|?` is actually `∞|✓`.
* Of the ones tested, Rust is faster in 6 and Java is faster in the other 10. And where Rust is faster, it is much
faster. This is quite decent since the Rust+WASM version uses `ByteBuffer`s everywhere, has some overflow checks, and
in general there are some impedance mismatches with the WASM bytecode and the JVM bytecode.
* Notice the low error numbers on the Rust versions. The error number is the deviation between invocations. This shows
the WASM-to-JVM ends up quite deterministic (or maybe, that there is just too much cruft to JIT, heh).
* If I were more serious about it, I'd check with other OS's, add more iterations, tweak some compiler options, include
regex pattern compilation speed benchmarks, and so on. But I just needed simple proof that speed is reasonable.
Overall, this shows running Rust on the JVM to be entirely reasonable for certain types of workloads. There are still
memory concerns, but not terribly. If given the choice, use a JVM language of course; the safety benefits of Rust don't
outweigh the problems of Rust-to-WASM-to-JVM such as build complexity, security concerns (`ByteBuffer` is where all
memory lives), debugging concerns, etc. But if you have a library in Rust, exposing it to the JVM sans-JNI is a doable
feat if you must.

View File

@ -0,0 +1,61 @@
package asmble.examples.rustregex;
import org.openjdk.jmh.annotations.*;
import java.io.IOException;
@State(Scope.Thread)
public class RegexBenchmark {
@Param({
"Twain",
"(?i)Twain",
"[a-z]shing",
"Huck[a-zA-Z]+|Saw[a-zA-Z]+",
"\\b\\w+nn\\b",
// Too slow
// "[a-q][^u-z]{13}x",
"Tom|Sawyer|Huckleberry|Finn",
"(?i)Tom|Sawyer|Huckleberry|Finn",
".{0,2}(Tom|Sawyer|Huckleberry|Finn)",
".{2,4}(Tom|Sawyer|Huckleberry|Finn)",
"Tom.{10,25}river|river.{10,25}Tom",
"[a-zA-Z]+ing",
"\\s[a-zA-Z]{0,12}ing\\s",
"([A-Za-z]awyer|[A-Za-z]inn)\\s",
"[\"'][^\"']{0,30}[?!\\.][\"']",
"\u221E|\u2713",
"\\p{Sm}"
})
private String patternString;
private String twainString;
private JavaLib javaLib;
private JavaLib.JavaPattern precompiledJavaPattern;
private RustLib rustLib;
private RustLib.Ptr preparedRustTarget;
private RustLib.RustPattern precompiledRustPattern;
@Setup
public void init() throws IOException {
// JMH is not handling this right, so we replace inline
if ("?|?".equals(patternString)) {
patternString = "\u221E|\u2713";
}
twainString = Main.loadTwainText();
javaLib = new JavaLib();
precompiledJavaPattern = javaLib.compile(patternString);
rustLib = new RustLib();
preparedRustTarget = rustLib.prepareTarget(twainString);
precompiledRustPattern = rustLib.compile(patternString);
}
@Benchmark
public void javaRegexCheck() {
precompiledJavaPattern.matchCount(twainString);
}
@Benchmark
public void rustRegexCheck() {
precompiledRustPattern.matchCount(preparedRustTarget);
}
}

View File

@ -0,0 +1,54 @@
#![feature(allocator_api)]
extern crate regex;
use std::ptr::NonNull;
use regex::Regex;
use std::alloc::{Alloc, Global, Layout};
use std::mem;
use std::str;
#[no_mangle]
pub extern "C" fn compile_pattern(str_ptr: *mut u8, len: usize) -> *mut Regex {
unsafe {
let bytes = Vec::<u8>::from_raw_parts(str_ptr, len, len);
let s = str::from_utf8_unchecked(&bytes);
let r = Box::new(Regex::new(s).unwrap());
Box::into_raw(r)
}
}
#[no_mangle]
pub extern "C" fn dispose_pattern(r: *mut Regex) {
unsafe {
let _r = Box::from_raw(r);
}
}
#[no_mangle]
pub extern "C" fn match_count(r: *mut Regex, str_ptr: *mut u8, len: usize) -> usize {
unsafe {
let bytes = Vec::<u8>::from_raw_parts(str_ptr, len, len);
let s = str::from_utf8_unchecked(&bytes);
let r = Box::from_raw(r);
let count = r.find_iter(s).count();
mem::forget(r);
count
}
}
#[no_mangle]
pub extern "C" fn alloc(size: usize) -> NonNull<u8> {
unsafe {
let layout = Layout::from_size_align(size, mem::align_of::<u8>()).unwrap();
Global.alloc(layout).unwrap()
}
}
#[no_mangle]
pub extern "C" fn dealloc(ptr: NonNull<u8>, size: usize) {
unsafe {
let layout = Layout::from_size_align(size, mem::align_of::<u8>()).unwrap();
Global.dealloc(ptr, layout);
}
}

View File

@ -0,0 +1,42 @@
package asmble.examples.rustregex;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
* Implementation of {@link RegexLib} based on `java.util.regex`.
*/
public class JavaLib implements RegexLib<String> {
@Override
public JavaPattern compile(String str) {
return new JavaPattern(str);
}
@Override
public String prepareTarget(String target) {
return target;
}
public class JavaPattern implements RegexPattern<String> {
private final Pattern pattern;
private JavaPattern(String pattern) {
this(Pattern.compile(pattern));
}
private JavaPattern(Pattern pattern) {
this.pattern = pattern;
}
@Override
public int matchCount(String target) {
Matcher matcher = pattern.matcher(target);
int count = 0;
while (matcher.find()) {
count++;
}
return count;
}
}
}

View File

@ -0,0 +1,36 @@
package asmble.examples.rustregex;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
public class Main {
public static void main(String[] args) throws Exception {
String twainString = loadTwainText();
System.out.println("'Twain' count in Java: " + matchCount(twainString, "Twain", new JavaLib()));
System.out.println("'Twain' count in Rust: " + matchCount(twainString, "Twain", new RustLib()));
}
public static <T> int matchCount(String target, String pattern, RegexLib<T> lib) {
RegexLib.RegexPattern<T> compiledPattern = lib.compile(pattern);
T preparedTarget = lib.prepareTarget(target);
return compiledPattern.matchCount(preparedTarget);
}
public static String loadTwainText() throws IOException {
ByteArrayOutputStream os = new ByteArrayOutputStream();
try (InputStream is = Main.class.getResourceAsStream("/twain-for-regex.txt")) {
byte[] buffer = new byte[0xFFFF];
while (true) {
int lastLen = is.read(buffer);
if (lastLen < 0) {
break;
}
os.write(buffer, 0, lastLen);
}
}
return new String(os.toByteArray(), StandardCharsets.ISO_8859_1);
}
}

View File

@ -0,0 +1,12 @@
package asmble.examples.rustregex;
public interface RegexLib<T> {
RegexPattern<T> compile(String str);
T prepareTarget(String target);
interface RegexPattern<T> {
int matchCount(T target);
}
}

View File

@ -0,0 +1,87 @@
package asmble.examples.rustregex;
import asmble.generated.RustRegex;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
/**
* Implementation of {@link RegexLib} based on `asmble.generated.RustRegex` that
* was composed from Rust code (see lib.rs).
*/
public class RustLib implements RegexLib<RustLib.Ptr> {
// 600 pages is enough for our use
private static final int PAGE_SIZE = 65536;
private static final int MAX_MEMORY = 600 * PAGE_SIZE;
private final RustRegex rustRegex;
public RustLib() {
rustRegex = new RustRegex(MAX_MEMORY);
}
@Override
public RustPattern compile(String str) {
return new RustPattern(str);
}
@Override
public Ptr prepareTarget(String target) {
return ptrFromString(target);
}
private Ptr ptrFromString(String str) {
byte[] bytes = str.getBytes(StandardCharsets.UTF_8);
Ptr ptr = new Ptr(bytes.length);
ptr.put(bytes);
return ptr;
}
public class RustPattern implements RegexPattern<Ptr> {
private final int pointer;
private RustPattern(String pattern) {
Ptr ptr = ptrFromString(pattern);
pointer = rustRegex.compile_pattern(ptr.offset, ptr.size);
}
@Override
protected void finalize() throws Throwable {
rustRegex.dispose_pattern(pointer);
}
@Override
public int matchCount(Ptr target) {
return rustRegex.match_count(pointer, target.offset, target.size);
}
}
public class Ptr {
final int offset;
final int size;
Ptr(int offset, int size) {
this.offset = offset;
this.size = size;
}
Ptr(int size) {
this(rustRegex.alloc(size), size);
}
void put(byte[] bytes) {
// Yeah, yeah, not thread safe
ByteBuffer memory = rustRegex.getMemory();
memory.position(offset);
memory.put(bytes);
}
@Override
protected void finalize() throws Throwable {
rustRegex.dealloc(offset, size);
}
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,67 @@
package asmble.examples.rustregex;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import java.io.IOException;
@RunWith(Parameterized.class)
public class RegexTest {
// Too slow to run regularly
private static final String TOO_SLOW = "[a-q][^u-z]{13}x";
@Parameterized.Parameters(name = "pattern: {0}")
public static String[] data() {
return new String[] {
"Twain",
"(?i)Twain",
"[a-z]shing",
"Huck[a-zA-Z]+|Saw[a-zA-Z]+",
"\\b\\w+nn\\b",
"[a-q][^u-z]{13}x",
"Tom|Sawyer|Huckleberry|Finn",
"(?i)Tom|Sawyer|Huckleberry|Finn",
".{0,2}(Tom|Sawyer|Huckleberry|Finn)",
".{2,4}(Tom|Sawyer|Huckleberry|Finn)",
"Tom.{10,25}river|river.{10,25}Tom",
"[a-zA-Z]+ing",
"\\s[a-zA-Z]{0,12}ing\\s",
"([A-Za-z]awyer|[A-Za-z]inn)\\s",
"[\"'][^\"']{0,30}[?!\\.][\"']",
"\u221E|\u2713",
"\\p{Sm}"
};
}
private static RustLib rustLib;
private static String twainText;
private static RustLib.Ptr preparedRustTarget;
@BeforeClass
public static void setUpClass() throws IOException {
twainText = Main.loadTwainText();
rustLib = new RustLib();
preparedRustTarget = rustLib.prepareTarget(twainText);
}
private String pattern;
public RegexTest(String pattern) {
this.pattern = pattern;
}
@Test
public void checkJavaVersusRust() {
Assume.assumeFalse("Skipped for being too slow", pattern.equals(TOO_SLOW));
int expected = new JavaLib().compile(pattern).matchCount(twainText);
// System.out.println("Found " + expected + " matches for pattern: " + pattern);
Assert.assertEquals(
expected,
rustLib.compile(pattern).matchCount(preparedRustTarget)
);
}
}

View File

@ -0,0 +1,2 @@
[build]
target = "wasm32-unknown-unknown"

View File

@ -0,0 +1,6 @@
[package]
name = "rust_simple"
version = "0.1.0"
[lib]
crate-type = ["cdylib"]

Some files were not shown because too many files have changed in this diff Show More