60 Commits

Author SHA1 Message Date
vms
4c65740d03 initial commit (#13) 2019-08-10 14:52:25 +03:00
vms
728b78d713 add EIC metering (#12) 2019-08-06 17:25:48 +03:00
vms
cb907ae2da Env module for gas metering (#11) 2019-08-06 10:45:16 +03:00
1d6002624f Fix memory builder (#10)
* fix memory instance creation

* change version
2019-06-03 14:02:15 +03:00
ad2b7c071f Bytebuffer abstraction (#9) 2019-05-13 16:40:07 +03:00
vms
119ce58c9e add using optional module name to registered names (#8) 2019-03-29 19:01:49 +03:00
vms
1323e02c95 fix bintray user and key properties absence (#4) 2018-11-19 15:25:31 +03:00
9172fba948 Add possibility to add a Logger Wasm module (#3)
* Add logger module

* Add LoggerModuleTest
2018-11-15 12:38:50 +04:00
b9b45cf997 Merge pull request #2 from fluencelabs/logger
Fix logger and return C example
2018-11-09 13:47:41 +04:00
2bfa39a3c6 Tweaking after merge 2018-11-09 10:30:38 +04:00
317b608048 Merge fix for late init for logger 2018-11-09 10:28:47 +04:00
21b023f1c6 Return C example and skip it by default 2018-11-09 10:18:49 +04:00
765d8b4dba Possibility to skip examples 2018-11-09 10:03:44 +04:00
58cf836b76 Disable Go and Rust examples and up project version 2018-10-01 13:40:38 +04:00
56c2c8d672 Fix lateinit error with logger for wasm files 2018-10-01 13:39:08 +04:00
fb0be9d31a Merge remote-tracking branch 'upstream/master'
# Conflicts:
#	compiler/src/test/resources/spec
2018-09-18 15:56:16 +04:00
d1f48aaaa0 Replace previous large-method-split attempt with msplit-based one for issue #19 2018-09-13 16:50:48 -05:00
46a8ce3f52 Updated to latest spec and minor fix on block insn insertion count 2018-09-13 13:10:11 -05:00
6352efaa96 Update Kotlin and ASM 2018-09-12 16:01:46 -05:00
326a0cdaba Change Go example to simple hello world 2018-09-12 15:49:49 -05:00
da70c9fca4 Fix for "lateinit property logger has not been initialized" 2018-08-28 16:50:37 +04:00
a1a5563367 Publishing the fork to binTray 2018-08-14 17:32:10 +04:00
e489e7c889 Publishing the fork to binTray 2018-08-14 17:22:30 +04:00
c1391b2701 Expected but isn't working version 2018-08-14 12:11:48 +04:00
6b28c5a93b merge original-master to master 2018-08-10 13:10:28 +04:00
9dc4112512 Add beginning of CLI for func split issue #19 2018-07-29 01:18:39 -05:00
01d89947e5 Complete main func splitting impl with test 2018-07-28 23:28:50 -05:00
4373676448 Begin test for large split func on issue #19 2018-07-27 17:03:55 -05:00
94953a4ada More work on large func splitter for issue #19 2018-07-27 15:16:37 -05:00
472579020a Early code for WASM function splitting for issue #19 2018-07-27 02:56:30 -05:00
75de1d76e3 Finish stack walker 2018-07-27 01:15:48 -05:00
67e914d683 Begin abstracting stack walking for issue #19 2018-07-26 17:17:43 -05:00
4bc12f4b94 Begin work on Go examples for #14 2018-07-26 17:16:11 -05:00
1990f46743 merge asmble-master to master 2018-07-26 10:49:28 +04:00
559df45f09 Merge pull request #1 from cretz/master
Fetch master changes
2018-07-26 10:34:42 +04:00
73862e9bc9 Chunk data sections even smaller per #18 and update README explanation 2018-07-26 00:27:17 -05:00
1127b61eb5 Update README explanation about string const max 2018-07-26 00:08:40 -05:00
a66c05ad4a Support large data sections. Fixes #18 2018-07-26 00:05:18 -05:00
6786350f53 Fixed to set proper stack diff size for store insns 2018-07-26 00:03:49 -05:00
706da0d486 Removed prefixed dollar sign from sexpr names and add export names to dedupe check for issue #17 2018-07-25 16:30:48 -05:00
1d5c1e527a Emit given names in compiled class. Fixes #17 2018-07-25 15:59:27 -05:00
1430bf48a6 Support names in converters for issue #17 2018-07-25 15:19:25 -05:00
96febbecd5 Beginning of name support for issue #17 2018-07-25 12:57:54 -05:00
80a8a1fbb9 Temporary disable rust-regex example 2018-07-25 10:35:37 +04:00
dd72c7124c Fix Rust Simple and Rust String examples 2018-07-25 10:21:38 +04:00
c04a3c4a9b Add some java docs 2018-07-24 11:00:41 +04:00
51520ac07d Add some java docs 2018-07-23 12:52:30 +04:00
3c25b40c40 Maven fetch instructions 2018-07-20 16:15:26 -05:00
96458bdec7 Maven publishing support. Fixes #15 2018-07-20 15:59:03 -05:00
dd33676e50 README update for mutable globals issue #16 2018-07-20 15:05:57 -05:00
e9364574a3 Changed resizing to memory_grow for default max mem pages due to https://github.com/WebAssembly/spec/pull/808 2018-07-20 14:59:04 -05:00
73e6b5769a Support mutable globals. Fixes #16 2018-07-20 14:43:54 -05:00
9d87ce440f Add "too many locals" error 2018-07-20 10:20:03 -05:00
51bc8008e1 Update to latest spec 2018-07-20 10:13:27 -05:00
198c521dd7 Update Kotlin and minor README tweak 2018-07-20 09:23:27 -05:00
97660de6ba Add some java docs 2018-07-20 12:52:21 +04:00
cee7a86773 Fix rust-regex example 2018-07-19 17:02:25 +04:00
cfa4a35af1 remove C example 2018-07-19 09:14:01 +04:00
f24342959d Update spec and two related changes (detailed)
* Added LEB128 validation (ref: https://github.com/WebAssembly/spec/pull/750)
* Rename memory instructions (ref: https://github.com/WebAssembly/spec/pull/720)
2018-05-07 15:39:31 -05:00
368ab300fa Update Kotlin 2018-05-07 11:43:34 -05:00
64 changed files with 2658 additions and 372 deletions

.gitignore

@@ -16,6 +16,8 @@
/annotations/out
/examples/c-simple/bin
/examples/c-simple/build
/examples/go-simple/bin
/examples/go-simple/build
/examples/rust-simple/Cargo.lock
/examples/rust-simple/bin
/examples/rust-simple/build


@@ -1,6 +1,6 @@
MIT License
Copyright (c) 2017 Chad Retz
Copyright (c) 2018 Chad Retz
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal


@@ -183,9 +183,16 @@ JVM languages.
### Getting
The latest tag can be added to your build script via [JitPack](https://jitpack.io). For example,
[here](https://jitpack.io/#cretz/asmble/0.1.0) are instructions for using the 0.1.0 release and
[here](https://jitpack.io/#cretz/asmble/master-SNAPSHOT) are instructions for the latest master.
The compiler and annotations are deployed to Maven Central. The compiler is written in Kotlin and can be added as a
Gradle dependency with:
compile 'com.github.cretz.asmble:asmble-compiler:0.3.0'
This is only needed to compile of course; the compiled code has no runtime requirement. The compiled code does include
some annotations (but in Java it's ok to have annotations that are not found). If you do want to reflect the annotations,
the annotation library can be added as a Gradle dependency with:
compile 'com.github.cretz.asmble:asmble-annotations:0.3.0'
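For reference, the same dependencies in a consumer project using the Gradle Kotlin DSL might look like the sketch below. The plugin block, configuration names, and overall layout are assumptions about the consuming project, not something this repository ships; they simply mirror the Groovy lines above.

// build.gradle.kts (sketch of a plain Java consumer project)
plugins {
    java
}

repositories {
    mavenCentral()
}

dependencies {
    // mirrors the Groovy snippets above; on newer Gradle versions use "implementation" instead of "compile"
    "compile"("com.github.cretz.asmble:asmble-compiler:0.3.0")
    "compile"("com.github.cretz.asmble:asmble-annotations:0.3.0")
}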
### Building and Testing
@@ -256,15 +263,16 @@ In the WebAssembly MVP a table is just a set of function pointers. This is store
#### Globals
Globals are stored as fields on the class. A non-import global is simply a field, but an import global is a
`MethodHandle` to the getter (and would be a `MethodHandle` to the setter if mutable globals were supported). Any values
for the globals are set in the constructor.
Globals are stored as fields on the class. A non-import global is simply a field that is final if not mutable. An import
global is a `MethodHandle` to the getter and a `MethodHandle` to the setter if mutable. Any values for the globals are
set in the constructor.
#### Imports
The constructor accepts all imports as params. Memory is imported via a `ByteBuffer` param, then function
imports as `MethodHandle` params, then global imports as `MethodHandle` params, then a `MethodHandle` array param for an
imported table. All of these values are set as fields in the constructor.
imports as `MethodHandle` params, then global imports as `MethodHandle` params (one for getter and another for setter if
mutable), then a `MethodHandle` array param for an imported table. All of these values are set as fields in the
constructor.
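As a rough illustration of the layout described above, a module with one imported function, one mutable imported global, and an imported table could be thought of as the following sketch. The names are invented and the real output is generated JVM bytecode, not Kotlin source; this only shows the shape of the fields and constructor parameters.

import java.lang.invoke.MethodHandle
import java.nio.ByteBuffer

class CompiledModuleSketch(
    private val memory: ByteBuffer,              // memory import comes first
    private val importedFunc: MethodHandle,      // one handle per imported function
    private val importedGlobalGet: MethodHandle, // imported global: getter handle...
    private val importedGlobalSet: MethodHandle, // ...plus a setter handle because it is mutable
    private val table: Array<MethodHandle>       // imported table as a MethodHandle array
) {
    private var mutableGlobal: Int = 0           // non-import mutable global: plain field
    private val immutableGlobal: Long = 1_000L   // non-import immutable global: final field
    // initial global values end up being assigned in the constructor, as noted above
}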
#### Exports
@@ -363,9 +371,12 @@ stack (e.g. some places where we do a swap).
Below are some performance and implementation quirks where there is a bit of an impedance mismatch between WebAssembly
and the JVM:
* WebAssembly has a nice data section for byte arrays whereas the JVM does not. Right now we build a byte array from
a bunch of consts at runtime which is multiple operations per byte. This can bloat the class file size, but is quite
fast compared to alternatives such as string constants.
* WebAssembly has a nice data section for byte arrays whereas the JVM does not. Right now we use a single-byte-char
string constant (i.e. ISO-8859 charset). This saves class file size, but this means we call `String::getBytes` on
init to load bytes from the string constant. Due to the JVM using an unsigned 16-bit int as the string constant
length, the maximum byte length is 65536. Since the string constants are stored as UTF-8 constants, they can be up to
four bytes a character. Therefore, we populate memory in data chunks no larger than 16300 (a nice round number to make
sure that even in the worst case of 4 bytes per char in the UTF-8 view, we're still under the max); see the sketch after this list.
* The JVM makes no guarantees about trailing bits being preserved on NaN floating point representations like WebAssembly
does. This causes some mismatch on WebAssembly tests depending on how the JVM "feels" (I haven't dug into why some
bit patterns stay and some don't when NaNs are passed through methods).
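A minimal sketch of the chunked initialization described in the data-section bullet above: each chunk is a string constant of at most 16300 characters, so even in the worst case of 4 bytes per character in the class file's UTF-8 view (16300 * 4 = 65200) the constant stays under the length limit. The literals and chunk boundaries are illustrative only.

import java.nio.ByteBuffer

fun initMemorySketch(memory: ByteBuffer) {
    // each string constant holds one chunk of the data section; decoding with a
    // single-byte charset maps every char back to exactly one byte
    memory.put("...chunk 1, up to 16300 chars...".toByteArray(Charsets.ISO_8859_1))
    memory.put("...chunk 2...".toByteArray(Charsets.ISO_8859_1))
}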
@@ -417,6 +428,8 @@ WASM compiled from Rust, C, Java, etc if e.g. they all have their own way of han
definition of an importable set of modules that does all of these things, even if it's in WebIDL. I dunno, maybe the
effort is already there, I haven't really looked.
There is https://github.com/konsoletyper/teavm
**So I can compile something in C via Emscripten and have it run on the JVM with this?**
Yes, but work is required. WebAssembly is lacking any kind of standard library. So Emscripten will either embed it or


@@ -13,4 +13,5 @@ public @interface WasmImport {
WasmExternalKind kind();
int resizableLimitInitial() default -1;
int resizableLimitMaximum() default -1;
boolean globalSetter() default false;
}


@@ -2,8 +2,8 @@ group 'asmble'
version '0.2.0'
buildscript {
ext.kotlin_version = '1.2.30'
ext.asm_version = '5.2'
ext.kotlin_version = '1.2.61'
ext.asm_version = '6.2.1'
repositories {
mavenCentral()
@@ -14,17 +14,34 @@ buildscript {
dependencies {
classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version"
classpath 'me.champeau.gradle:jmh-gradle-plugin:0.4.5'
classpath 'com.jfrog.bintray.gradle:gradle-bintray-plugin:1.8.4'
}
}
allprojects {
apply plugin: 'java'
group 'com.github.cretz.asmble'
version '0.4.11-fl'
// skips building and running for the specified examples
ext.skipExamples = ['c-simple', 'go-simple', 'rust-regex']
// todo: disabling rust-regex is temporary because some strings in the wasm code exceed 65353 bytes in size.
repositories {
mavenCentral()
}
}
project(':annotations') {
javadoc {
options.links 'https://docs.oracle.com/javase/8/docs/api/'
// TODO: change when https://github.com/gradle/gradle/issues/2354 is fixed
options.addStringOption 'Xdoclint:all', '-Xdoclint:-missing'
}
publishSettings(project, 'asmble-annotations', 'Asmble WASM Annotations')
}
project(':compiler') {
apply plugin: 'kotlin'
apply plugin: 'application'
@@ -41,10 +58,12 @@ project(':compiler') {
compile "org.jetbrains.kotlin:kotlin-reflect:$kotlin_version"
compile "org.ow2.asm:asm-tree:$asm_version"
compile "org.ow2.asm:asm-util:$asm_version"
compile "org.ow2.asm:asm-commons:$asm_version"
testCompile 'junit:junit:4.12'
testCompile "org.jetbrains.kotlin:kotlin-test-junit:$kotlin_version"
testCompile "org.ow2.asm:asm-debug-all:$asm_version"
}
publishSettings(project, 'asmble-compiler', 'Asmble WASM Compiler')
}
project(':examples') {
@@ -52,7 +71,7 @@ project(':examples') {
dependencies {
compileOnly project(':compiler')
}
// C/C++ example helpers
task cToWasm {
@@ -62,7 +81,7 @@ project(':examples') {
def cFileName = fileTree(dir: 'src', includes: ['*.c']).files.iterator().next()
commandLine 'clang', '--target=wasm32-unknown-unknown-wasm', '-O3', cFileName, '-c', '-o', 'build/lib.wasm'
}
}
}
}
task showCWast(type: JavaExec) {
@@ -85,6 +104,31 @@ project(':examples') {
}
}
// Go example helpers
task goToWasm {
doFirst {
mkdir 'build'
exec {
def goFileName = fileTree(dir: '.', includes: ['*.go']).files.iterator().next()
environment 'GOOS': 'js', 'GOARCH': 'wasm'
commandLine 'go', 'build', '-o', 'build/lib.wasm', goFileName
}
}
}
task compileGoWasm(type: JavaExec) {
dependsOn goToWasm
classpath configurations.compileClasspath
main = 'asmble.cli.MainKt'
doFirst {
// args 'help', 'compile'
def outFile = 'build/wasm-classes/' + wasmCompiledClassName.replace('.', '/') + '.class'
file(outFile).parentFile.mkdirs()
args 'compile', 'build/lib.wasm', wasmCompiledClassName, '-out', outFile, '-log', 'debug'
}
}
// Rust example helpers
ext.rustBuildRelease = true
@@ -135,41 +179,81 @@ }
}
project(':examples:c-simple') {
if (project.name in skipExamples) {
println("[Note!] Building and runnig for ${project.name} was skipped")
test.onlyIf { false } // explicit skipping tests
compileJava.onlyIf { false } // explicit skipping compile
return
}
apply plugin: 'application'
ext.wasmCompiledClassName = 'asmble.generated.CSimple'
dependencies {
compile files('build/wasm-classes')
}
compileJava {
dependsOn compileCWasm
}
mainClassName = 'asmble.examples.csimple.Main'
}
project(':examples:rust-regex') {
project(':examples:go-simple') {
if (project.name in skipExamples) {
println("[Note!] Building and runnig for ${project.name} was skipped")
test.onlyIf { false } // explicit skipping tests
compileJava.onlyIf { false } // explicit skipping compile
return
}
apply plugin: 'application'
apply plugin: 'me.champeau.gradle.jmh'
ext.wasmCompiledClassName = 'asmble.generated.RustRegex'
ext.wasmCompiledClassName = 'asmble.generated.GoSimple'
dependencies {
compile files('build/wasm-classes')
testCompile 'junit:junit:4.12'
}
compileJava {
dependsOn compileRustWasm
}
mainClassName = 'asmble.examples.rustregex.Main'
test {
testLogging.showStandardStreams = true
testLogging.events 'PASSED', 'SKIPPED'
}
jmh {
iterations = 5
warmupIterations = 5
fork = 3
dependsOn compileGoWasm
}
mainClassName = 'asmble.examples.gosimple.Main'
}
project(':examples:rust-regex') {
if (project.name in skipExamples) {
println("[Note!] Building and runnig for ${project.name} was skipped")
test.onlyIf { false } // explicit skipping tests
compileJava.onlyIf { false } // explicit skipping compile
compileTestJava.onlyIf { false } // explicit skipping compile
return
}
apply plugin: 'application'
apply plugin: 'me.champeau.gradle.jmh'
ext.wasmCompiledClassName = 'asmble.generated.RustRegex'
dependencies {
compile files('build/wasm-classes')
testCompile 'junit:junit:4.12'
}
compileJava {
dependsOn compileRustWasm
}
mainClassName = 'asmble.examples.rustregex.Main'
test {
testLogging.showStandardStreams = true
testLogging.events 'PASSED', 'SKIPPED'
}
jmh {
iterations = 5
warmupIterations = 5
fork = 3
}
}
project(':examples:rust-simple') {
if (project.name in skipExamples) {
println("[Note!] Building and runnig for ${project.name} was skipped")
test.onlyIf { false } // explicit skipping tests
compileJava.onlyIf { false } // explicit skipping compile
return
}
apply plugin: 'application'
ext.wasmCompiledClassName = 'asmble.generated.RustSimple'
dependencies {
@@ -182,6 +266,12 @@ project(':examples:rust-simple') {
}
project(':examples:rust-string') {
if (project.name in skipExamples) {
println("[Note!] Building and runnig for ${project.name} was skipped")
test.onlyIf { false } // explicit skipping tests
compileJava.onlyIf { false } // explicit skipping compile
return
}
apply plugin: 'application'
ext.wasmCompiledClassName = 'asmble.generated.RustString'
dependencies {
@@ -191,4 +281,61 @@ project(':examples:rust-string') {
dependsOn compileRustWasm
}
mainClassName = 'asmble.examples.ruststring.Main'
}
}
def publishSettings(project, projectName, projectDescription) {
project.with {
apply plugin: 'com.jfrog.bintray'
apply plugin: 'maven-publish'
apply plugin: 'maven'
task sourcesJar(type: Jar) {
from sourceSets.main.allJava
classifier = 'sources'
}
publishing {
publications {
MyPublication(MavenPublication) {
from components.java
groupId group
artifactId projectName
artifact sourcesJar
version version
}
}
}
bintray {
if(!hasProperty("bintrayUser") || !hasProperty("bintrayKey")) {
return
}
user = bintrayUser
key = bintrayKey
publications = ['MyPublication']
//[Default: false] Whether to override version artifacts already published
override = false
//[Default: false] Whether version should be auto published after an upload
publish = true
pkg {
repo = 'releases'
name = projectName
userOrg = 'fluencelabs'
licenses = ['MIT']
vcsUrl = 'https://github.com/fluencelabs/asmble'
version {
name = project.version
desc = projectDescription
released = new Date()
vcsTag = project.version
}
}
}
}
}


@@ -0,0 +1,40 @@
package asmble.compile.jvm;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
/**
* The abstraction that describes working with the memory of the virtual machine.
*/
public abstract class MemoryBuffer {
/**
* The default implementation of MemoryBuffer that is based on java.nio.DirectByteBuffer
*/
public static MemoryBuffer init(int capacity) {
return new MemoryByteBuffer(ByteBuffer.allocateDirect(capacity));
}
public abstract int capacity();
public abstract int limit();
public abstract MemoryBuffer clear();
public abstract MemoryBuffer limit(int newLimit);
public abstract MemoryBuffer position(int newPosition);
public abstract MemoryBuffer order(ByteOrder order);
public abstract MemoryBuffer duplicate();
public abstract MemoryBuffer put(byte[] arr, int offset, int length);
public abstract MemoryBuffer put(byte[] arr);
public abstract MemoryBuffer put(int index, byte b);
public abstract MemoryBuffer putInt(int index, int n);
public abstract MemoryBuffer putLong(int index, long n);
public abstract MemoryBuffer putDouble(int index, double n);
public abstract MemoryBuffer putShort(int index, short n);
public abstract MemoryBuffer putFloat(int index, float n);
public abstract byte get(int index);
public abstract int getInt(int index);
public abstract long getLong(int index);
public abstract short getShort(int index);
public abstract float getFloat(int index);
public abstract double getDouble(int index);
public abstract MemoryBuffer get(byte[] arr);
}


@@ -0,0 +1,8 @@
package asmble.compile.jvm;
/**
* Interface to initialize MemoryBuffer
*/
public interface MemoryBufferBuilder {
MemoryBuffer build(int capacity);
}
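Since MemoryBufferBuilder is a single-method Java interface, Kotlin SAM conversion gives a one-line builder. The sketch below simply delegates to the default implementation; how the builder is then handed to the compiler or runtime is outside this snippet and left as an assumption.

import asmble.compile.jvm.MemoryBuffer
import asmble.compile.jvm.MemoryBufferBuilder

// builds the default direct-buffer-backed memory; a custom MemoryBuffer subclass
// (for example one that meters reads and writes) could be returned here instead
val defaultMemoryBuilder = MemoryBufferBuilder { capacity -> MemoryBuffer.init(capacity) }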


@@ -0,0 +1 @@
Taken from https://github.com/cretz/msplit


@@ -0,0 +1,286 @@
package asmble.compile.jvm.msplit;
import org.objectweb.asm.Label;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.*;
import java.util.*;
import static asmble.compile.jvm.msplit.Util.*;
/** Splits a method into two */
public class SplitMethod {
protected final int api;
/** @param api Same as for {@link org.objectweb.asm.MethodVisitor#MethodVisitor(int)} or any other ASM class */
public SplitMethod(int api) { this.api = api; }
/**
* Calls {@link #split(String, MethodNode, int, int, int)} with minSize as 20% + 1 of the original, maxSize as
* 70% + 1 of the original, and firstAtLeast as maxSize. The original method is never modified and the result can
* be null if no split points are found.
*/
public Result split(String owner, MethodNode method) {
// Between 20% + 1 and 70% + 1 of size
int insnCount = method.instructions.size();
int minSize = (int) (insnCount * 0.2) + 1;
int maxSize = (int) (insnCount * 0.7) + 1;
return split(owner, method, minSize, maxSize, maxSize);
}
/**
* Splits the given method into two. This uses a {@link Splitter} to consistently create
* {@link asmble.compile.jvm.msplit.Splitter.SplitPoint}s until one reaches firstAtLeast or the largest otherwise, and then calls
* {@link #fromSplitPoint(String, MethodNode, Splitter.SplitPoint)}.
*
* @param owner The internal name of the owning class. Needed when splitting to call the split off method.
* @param method The method to split, never modified
* @param minSize The minimum number of instructions the split off method must have
* @param maxSize The maximum number of instructions the split off method can have
* @param firstAtLeast The number of instructions that, when first reached, will immediately be used without
* continuing. Since split points are streamed, this allows splitting without waiting to
find the largest overall. If this is &lt;= 0, it will not apply and all split points will be
* checked to find the largest before doing the split.
* @return The resulting split method or null if there were no split points found
*/
public Result split(String owner, MethodNode method, int minSize, int maxSize, int firstAtLeast) {
// Get the largest split point
Splitter.SplitPoint largest = null;
for (Splitter.SplitPoint point : new Splitter(api, owner, method, minSize, maxSize)) {
if (largest == null || point.length > largest.length) {
largest = point;
// Early exit?
if (firstAtLeast > 0 && largest.length >= firstAtLeast) break;
}
}
if (largest == null) return null;
return fromSplitPoint(owner, method, largest);
}
/**
* Split the given method at the given split point. Called by {@link #split(String, MethodNode, int, int, int)}. The
* original method is never modified.
*/
public Result fromSplitPoint(String owner, MethodNode orig, Splitter.SplitPoint splitPoint) {
MethodNode splitOff = createSplitOffMethod(orig, splitPoint);
MethodNode trimmed = createTrimmedMethod(owner, orig, splitOff, splitPoint);
return new Result(trimmed, splitOff);
}
protected MethodNode createSplitOffMethod(MethodNode orig, Splitter.SplitPoint splitPoint) {
// The new method is a static synthetic method named method.name + "$split" that returns an object array
// Key is previous local index, value is new local index
Map<Integer, Integer> localsMap = new HashMap<>();
// The new method's parameters are all stack items + all read locals
List<Type> args = new ArrayList<>(splitPoint.neededFromStackAtStart);
splitPoint.localsRead.forEach((index, type) -> {
args.add(type);
localsMap.put(index, args.size() - 1);
});
// Create the new method
MethodNode newMethod = new MethodNode(api,
Opcodes.ACC_STATIC + Opcodes.ACC_PRIVATE + Opcodes.ACC_SYNTHETIC, orig.name + "$split",
Type.getMethodDescriptor(Type.getType(Object[].class), args.toArray(new Type[0])), null, null);
// Add the written locals to the map that are not already there
int newLocalIndex = args.size();
for (Integer key : splitPoint.localsWritten.keySet()) {
if (!localsMap.containsKey(key)) {
localsMap.put(key, newLocalIndex);
newLocalIndex++;
}
}
// First set of instructions is pushing the new stack from the params
for (int i = 0; i < splitPoint.neededFromStackAtStart.size(); i++) {
Type item = splitPoint.neededFromStackAtStart.get(i);
newMethod.visitVarInsn(loadOpFromType(item), i);
}
// Next set of instructions comes verbatim from the original, but we have to change the local indexes
Set<Label> seenLabels = new HashSet<>();
for (int i = 0; i < splitPoint.length; i++) {
AbstractInsnNode insn = orig.instructions.get(i + splitPoint.start);
// Skip frames
if (insn instanceof FrameNode) continue;
// Store the label
if (insn instanceof LabelNode) seenLabels.add(((LabelNode) insn).getLabel());
// Change the local if needed
if (insn instanceof VarInsnNode) {
insn = insn.clone(Collections.emptyMap());
((VarInsnNode) insn).var = localsMap.get(((VarInsnNode) insn).var);
} else if (insn instanceof IincInsnNode) {
insn = insn.clone(Collections.emptyMap());
((IincInsnNode) insn).var = localsMap.get(((IincInsnNode) insn).var);
}
insn.accept(newMethod);
}
// Final set of instructions is an object array of stack to set and then locals written
// Create the object array
int retArrSize = splitPoint.putOnStackAtEnd.size() + splitPoint.localsWritten.size();
intConst(retArrSize).accept(newMethod);
newMethod.visitTypeInsn(Opcodes.ANEWARRAY, OBJECT_TYPE.getInternalName());
// So, we're going to store the arr in the next avail local
int retArrLocalIndex = newLocalIndex;
newMethod.visitVarInsn(Opcodes.ASTORE, retArrLocalIndex);
// Now go over each stack item and load the arr, swap w/ the stack, add the index, swap with the stack, and store
for (int i = splitPoint.putOnStackAtEnd.size() - 1; i >= 0; i--) {
Type item = splitPoint.putOnStackAtEnd.get(i);
// Box the item on the stack if necessary
boxStackIfNecessary(item, newMethod);
// Load the array
newMethod.visitVarInsn(Opcodes.ALOAD, retArrLocalIndex);
// Swap to put stack back on top
newMethod.visitInsn(Opcodes.SWAP);
// Add the index
intConst(i).accept(newMethod);
// Swap to put the stack value back on top
newMethod.visitInsn(Opcodes.SWAP);
// Now that we have arr, index, value, we can store in the array
newMethod.visitInsn(Opcodes.AASTORE);
}
// Do the same with written locals
int currIndex = splitPoint.putOnStackAtEnd.size();
for (Integer index : splitPoint.localsWritten.keySet()) {
Type item = splitPoint.localsWritten.get(index);
// Load the array
newMethod.visitVarInsn(Opcodes.ALOAD, retArrLocalIndex);
// Add the arr index
intConst(currIndex).accept(newMethod);
currIndex++;
// Load the var
newMethod.visitVarInsn(loadOpFromType(item), localsMap.get(index));
// Box it if necessary
boxStackIfNecessary(item, newMethod);
// Store in array
newMethod.visitInsn(Opcodes.AASTORE);
}
// Load the array out and return it
newMethod.visitVarInsn(Opcodes.ALOAD, retArrLocalIndex);
newMethod.visitInsn(Opcodes.ARETURN);
// Any try catch blocks that start in here
for (TryCatchBlockNode tryCatch : orig.tryCatchBlocks) {
if (seenLabels.contains(tryCatch.start.getLabel())) tryCatch.accept(newMethod);
}
// Reset the labels
newMethod.instructions.resetLabels();
return newMethod;
}
protected MethodNode createTrimmedMethod(String owner, MethodNode orig,
MethodNode splitOff, Splitter.SplitPoint splitPoint) {
// The trimmed method is the same as the original, yet the split area is replaced with a call to the split off
// portion. Before calling the split-off, we have to add locals to the stack part. Then afterwards, we have to
// replace the stack and written locals.
// Effectively clone the orig
MethodNode newMethod = new MethodNode(api, orig.access, orig.name, orig.desc,
orig.signature, orig.exceptions.toArray(new String[0]));
orig.accept(newMethod);
// Remove all insns, we'll re-add the ones outside the split range
newMethod.instructions.clear();
// Remove all try catch blocks and keep track of seen labels, we'll re-add them at the end
newMethod.tryCatchBlocks.clear();
Set<Label> seenLabels = new HashSet<>();
// Also keep track of the locals that have been stored, need to know
Set<Integer> seenStoredLocals = new HashSet<>();
// If this is an instance method, we consider "0" (i.e. "this") as seen
if ((orig.access & Opcodes.ACC_STATIC) == 0) seenStoredLocals.add(0);
// Add the insns before split
for (int i = 0; i < splitPoint.start; i++) {
AbstractInsnNode insn = orig.instructions.get(i);
// Skip frames
if (insn instanceof FrameNode) continue;
// Record label
if (insn instanceof LabelNode) seenLabels.add(((LabelNode) insn).getLabel());
// Check a local store has happened
if (insn instanceof VarInsnNode && isStoreOp(insn.getOpcode())) seenStoredLocals.add(((VarInsnNode) insn).var);
insn.accept(newMethod);
}
// Push all the read locals on the stack
splitPoint.localsRead.forEach((index, type) -> {
// We've seen a store for this, so just load it, otherwise use a zero val
// TODO: safe? if not, maybe just put at the top of the method a bunch of defaulted locals?
if (seenStoredLocals.contains(index)) newMethod.visitVarInsn(loadOpFromType(type), index);
else zeroVal(type).accept(newMethod);
});
// Invoke the split off method
newMethod.visitMethodInsn(Opcodes.INVOKESTATIC, owner, splitOff.name, splitOff.desc, false);
// Now the object array is on the stack which contains stack pieces + written locals
// Take off the locals
int localArrIndex = splitPoint.putOnStackAtEnd.size();
for (Integer index : splitPoint.localsWritten.keySet()) {
// Dupe the array
newMethod.visitInsn(Opcodes.DUP);
// Put the index on the stack
intConst(localArrIndex).accept(newMethod);
localArrIndex++;
// Load the written local
Type item = splitPoint.localsWritten.get(index);
newMethod.visitInsn(Opcodes.AALOAD);
// Cast to local type
if (!item.equals(OBJECT_TYPE)) {
newMethod.visitTypeInsn(Opcodes.CHECKCAST, boxedTypeIfNecessary(item).getInternalName());
}
// Unbox if necessary
unboxStackIfNecessary(item, newMethod);
// Store in the local
newMethod.visitVarInsn(storeOpFromType(item), index);
}
// Now just load up the stack
for (int i = 0; i < splitPoint.putOnStackAtEnd.size(); i++) {
boolean last = i == splitPoint.putOnStackAtEnd.size() - 1;
// Since the loop started with the array, we only dupe the array every time but the last
if (!last) newMethod.visitInsn(Opcodes.DUP);
// Put the index on the stack
intConst(i).accept(newMethod);
// Load the stack item
Type item = splitPoint.putOnStackAtEnd.get(i);
newMethod.visitInsn(Opcodes.AALOAD);
// Cast to local type
if (!item.equals(OBJECT_TYPE)) {
newMethod.visitTypeInsn(Opcodes.CHECKCAST, boxedTypeIfNecessary(item).getInternalName());
}
// Unbox if necessary
unboxStackIfNecessary(item, newMethod);
// For all but the last stack item, we need to swap with the arr ref above.
if (!last) {
// Note if the stack item takes two slots, we do a form of dup then pop since there's no swap1x2
if (item == Type.LONG_TYPE || item == Type.DOUBLE_TYPE) {
newMethod.visitInsn(Opcodes.DUP_X2);
newMethod.visitInsn(Opcodes.POP);
} else {
newMethod.visitInsn(Opcodes.SWAP);
}
}
}
// Now we have restored all locals and all stack...add the rest of the insns after the split
for (int i = splitPoint.start + splitPoint.length; i < orig.instructions.size(); i++) {
AbstractInsnNode insn = orig.instructions.get(i);
// Skip frames
if (insn instanceof FrameNode) continue;
// Record label
if (insn instanceof LabelNode) seenLabels.add(((LabelNode) insn).getLabel());
insn.accept(newMethod);
}
// Add any try catch blocks that started in here
for (TryCatchBlockNode tryCatch : orig.tryCatchBlocks) {
if (seenLabels.contains(tryCatch.start.getLabel())) tryCatch.accept(newMethod);
}
// Reset the labels
newMethod.instructions.resetLabels();
return newMethod;
}
/** Result of a split method */
public static class Result {
/** A copy of the original method, but changed to invoke {@link #splitOffMethod} */
public final MethodNode trimmedMethod;
/** The new method that was split off the original and is called by {@link #splitOffMethod} */
public final MethodNode splitOffMethod;
public Result(MethodNode trimmedMethod, MethodNode splitOffMethod) {
this.trimmedMethod = trimmedMethod;
this.splitOffMethod = splitOffMethod;
}
}
}
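A sketch of driving SplitMethod over an oversized method with the API shown above; the surrounding ClassNode handling and names are assumptions, and error handling is omitted.

import asmble.compile.jvm.msplit.SplitMethod
import org.objectweb.asm.Opcodes
import org.objectweb.asm.tree.ClassNode
import org.objectweb.asm.tree.MethodNode

fun splitIfPossible(classNode: ClassNode, big: MethodNode) {
    // uses the default 20%/70% sizing described in the Javadoc above; null means no split point was found
    val result = SplitMethod(Opcodes.ASM6).split(classNode.name, big) ?: return
    // swap in the trimmed copy and add the synthetic split-off method it calls
    classNode.methods.remove(big)
    classNode.methods.add(result.trimmedMethod)
    classNode.methods.add(result.splitOffMethod)
}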


@@ -0,0 +1,392 @@
package asmble.compile.jvm.msplit;
import org.objectweb.asm.Label;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;
import org.objectweb.asm.commons.AnalyzerAdapter;
import org.objectweb.asm.tree.*;
import java.util.*;
import static asmble.compile.jvm.msplit.Util.*;
/** For a given method, iterate over possible split points */
public class Splitter implements Iterable<Splitter.SplitPoint> {
protected final int api;
protected final String owner;
protected final MethodNode method;
protected final int minSize;
protected final int maxSize;
/**
* @param api Same as for {@link org.objectweb.asm.MethodVisitor#MethodVisitor(int)} or any other ASM class
* @param owner Internal name of the method's owner
* @param method The method to find split points for
* @param minSize The minimum number of instructions required for the split point to be valid
* @param maxSize The maximum number of instructions that split points cannot exceed
*/
public Splitter(int api, String owner, MethodNode method, int minSize, int maxSize) {
this.api = api;
this.owner = owner;
this.method = method;
this.minSize = minSize;
this.maxSize = maxSize;
}
@Override
public Iterator<SplitPoint> iterator() { return new Iter(); }
// Types are always int, float, long, double, or ref (no other primitives)
/** A split point in a method that can be split off into another method */
public static class SplitPoint {
/**
* The locals read in this split area, keyed by index. Value type is always int, float, long, double, or object.
*/
public final SortedMap<Integer, Type> localsRead;
/**
* The locals written in this split area, keyed by index. Value type is always int, float, long, double, or object.
*/
public final SortedMap<Integer, Type> localsWritten;
/**
* The values of the stack needed at the start of this split area. Type is always int, float, long, double, or
* object.
*/
public final List<Type> neededFromStackAtStart;
/**
* The values of the stack at the end of this split area that are needed to put back on the original. Type is always
* int, float, long, double, or object.
*/
public final List<Type> putOnStackAtEnd;
/**
* The instruction index this split area begins at.
*/
public final int start;
/**
* The number of instructions this split area has.
*/
public final int length;
public SplitPoint(SortedMap<Integer, Type> localsRead, SortedMap<Integer, Type>localsWritten,
List<Type> neededFromStackAtStart, List<Type> putOnStackAtEnd, int start, int length) {
this.localsRead = localsRead;
this.localsWritten = localsWritten;
this.neededFromStackAtStart = neededFromStackAtStart;
this.putOnStackAtEnd = putOnStackAtEnd;
this.start = start;
this.length = length;
}
}
protected int compareInsnIndexes(AbstractInsnNode o1, AbstractInsnNode o2) {
return Integer.compare(method.instructions.indexOf(o1), method.instructions.indexOf(o2));
}
protected class Iter implements Iterator<SplitPoint> {
protected final AbstractInsnNode[] insns;
protected final List<TryCatchBlockNode> tryCatchBlocks;
protected int currIndex = -1;
protected boolean peeked;
protected SplitPoint peekedValue;
protected Iter() {
insns = method.instructions.toArray();
tryCatchBlocks = new ArrayList<>(method.tryCatchBlocks);
// Must be sorted by earliest starting index then earliest end index then earliest handler
tryCatchBlocks.sort((o1, o2) -> {
int cmp = compareInsnIndexes(o1.start, o2.start);
if (cmp == 0) cmp = compareInsnIndexes(o1.end, o2.end);
if (cmp == 0) cmp = compareInsnIndexes(o1.handler, o2.handler);
return cmp;
});
}
@Override
public boolean hasNext() {
if (!peeked) {
peeked = true;
peekedValue = nextOrNull();
}
return peekedValue != null;
}
@Override
public SplitPoint next() {
// If we've peeked in hasNext, use that
SplitPoint ret;
if (peeked) {
peeked = false;
ret = peekedValue;
} else {
ret = nextOrNull();
}
if (ret == null) throw new NoSuchElementException();
return ret;
}
protected SplitPoint nextOrNull() {
// Try for each index
while (++currIndex + minSize <= insns.length) {
SplitPoint longest = longestForCurrIndex();
if (longest != null) return longest;
}
return null;
}
protected SplitPoint longestForCurrIndex() {
// As a special case, if the previous insn was a line number, that was good enough
if (currIndex - 1 >= 0 && insns[currIndex - 1] instanceof LineNumberNode) return null;
// Build the info object
InsnTraverseInfo info = new InsnTraverseInfo();
info.startIndex = currIndex;
info.endIndex = Math.min(currIndex + maxSize - 1, insns.length - 1);
// Reduce the end based on try/catch blocks the start is in or that jump to
constrainEndByTryCatchBlocks(info);
// Reduce the end based on any jumps within
constrainEndByInternalJumps(info);
// Reduce the end based on any jumps into
constrainEndByExternalJumps(info);
// Make sure we didn't reduce the end too far
if (info.getSize() < minSize) return null;
// Now that we have our largest range from the start index, we can go over each updating the local refs and stack
// For the stack, we are going to use the
return splitPointFromInfo(info);
}
protected void constrainEndByTryCatchBlocks(InsnTraverseInfo info) {
// Go over all the try/catch blocks, sorted by earliest
for (TryCatchBlockNode block : tryCatchBlocks) {
int handleIndex = method.instructions.indexOf(block.handler);
int startIndex = method.instructions.indexOf(block.start);
int endIndex = method.instructions.indexOf(block.end) - 1;
boolean catchWithinDisallowed;
if (info.startIndex <= startIndex && info.endIndex >= endIndex) {
// The try block is entirely inside the range...
catchWithinDisallowed = false;
// Since it's entirely within, we need the catch handler within too
if (handleIndex < info.startIndex || handleIndex > info.endIndex) {
// Well, it's not within, so that means we can't include this try block at all
info.endIndex = Math.min(info.endIndex, startIndex - 1);
}
} else if (info.startIndex > startIndex && info.endIndex > endIndex) {
// The try block started before this range, but ends inside of it...
// The end has to be changed to the block's end so it doesn't go over the boundary
info.endIndex = Math.min(info.endIndex, endIndex);
// The catch can't jump in here
catchWithinDisallowed = true;
} else if (info.startIndex <= startIndex && info.endIndex < endIndex) {
// The try block started in this range, but ends outside of it...
// Can't have the block then, reduce it to before the start
info.endIndex = Math.min(info.endIndex, startIndex - 1);
// Since we don't have the block, we can't jump in here either
catchWithinDisallowed = true;
} else {
// The try block is completely outside, just restrict the catch from jumping in
catchWithinDisallowed = true;
}
// If the catch is within and not allowed to be, we have to change the end to before it
if (catchWithinDisallowed && info.startIndex <= handleIndex && info.endIndex >= handleIndex) {
info.endIndex = Math.min(info.endIndex, handleIndex - 1);
}
}
}
protected void constrainEndByInternalJumps(InsnTraverseInfo info) {
for (int i = info.startIndex; i <= info.endIndex; i++) {
AbstractInsnNode node = insns[i];
int earliestIndex;
int furthestIndex;
if (node instanceof JumpInsnNode) {
earliestIndex = method.instructions.indexOf(((JumpInsnNode) node).label);
furthestIndex = earliestIndex;
} else if (node instanceof TableSwitchInsnNode) {
earliestIndex = method.instructions.indexOf(((TableSwitchInsnNode) node).dflt);
furthestIndex = earliestIndex;
for (LabelNode label : ((TableSwitchInsnNode) node).labels) {
int index = method.instructions.indexOf(label);
earliestIndex = Math.min(earliestIndex, index);
furthestIndex = Math.max(furthestIndex, index);
}
} else if (node instanceof LookupSwitchInsnNode) {
earliestIndex = method.instructions.indexOf(((LookupSwitchInsnNode) node).dflt);
furthestIndex = earliestIndex;
for (LabelNode label : ((LookupSwitchInsnNode) node).labels) {
int index = method.instructions.indexOf(label);
earliestIndex = Math.min(earliestIndex, index);
furthestIndex = Math.max(furthestIndex, index);
}
} else continue;
// Stop here if any indexes are out of range, otherwise, change end
if (earliestIndex < info.startIndex || furthestIndex > info.endIndex) {
info.endIndex = i - 1;
return;
}
info.endIndex = Math.max(info.endIndex, furthestIndex);
}
}
protected void constrainEndByExternalJumps(InsnTraverseInfo info) {
// Basically, if any external jumps jump into our range, that can't be included in the range
for (int i = 0; i < insns.length; i++) {
if (i >= info.startIndex && i <= info.endIndex) continue;
AbstractInsnNode node = insns[i];
if (node instanceof JumpInsnNode) {
int index = method.instructions.indexOf(((JumpInsnNode) node).label);
if (index >= info.startIndex) info.endIndex = Math.min(info.endIndex, index - 1);
} else if (node instanceof TableSwitchInsnNode) {
int index = method.instructions.indexOf(((TableSwitchInsnNode) node).dflt);
if (index >= info.startIndex) info.endIndex = Math.min(info.endIndex, index - 1);
for (LabelNode label : ((TableSwitchInsnNode) node).labels) {
index = method.instructions.indexOf(label);
if (index >= info.startIndex) info.endIndex = Math.min(info.endIndex, index - 1);
}
} else if (node instanceof LookupSwitchInsnNode) {
int index = method.instructions.indexOf(((LookupSwitchInsnNode) node).dflt);
if (index >= info.startIndex) info.endIndex = Math.min(info.endIndex, index - 1);
for (LabelNode label : ((LookupSwitchInsnNode) node).labels) {
index = method.instructions.indexOf(label);
if (index >= info.startIndex) info.endIndex = Math.min(info.endIndex, index - 1);
}
}
}
}
protected SplitPoint splitPointFromInfo(InsnTraverseInfo info) {
// We're going to use the analyzer adapter and run it up until the end, a step at a time
StackAndLocalTrackingAdapter adapter = new StackAndLocalTrackingAdapter(Splitter.this);
// Visit all of the insns up our start.
// XXX: I checked the source of AnalyzerAdapter to confirm I don't need any of the surrounding stuff
for (int i = 0; i < info.startIndex; i++) insns[i].accept(adapter);
// Take the stack at the start and copy it off
List<Object> stackAtStart = new ArrayList<>(adapter.stack);
// Reset some adapter state
adapter.lowestStackSize = stackAtStart.size();
adapter.localsRead.clear();
adapter.localsWritten.clear();
// Now go over the remaining range
for (int i = info.startIndex; i <= info.endIndex; i++) insns[i].accept(adapter);
// Build the split point
return new SplitPoint(
localMapFromAdapterLocalMap(adapter.localsRead, adapter.uninitializedTypes),
localMapFromAdapterLocalMap(adapter.localsWritten, adapter.uninitializedTypes),
typesFromAdapterStackRange(stackAtStart, adapter.lowestStackSize, adapter.uninitializedTypes),
typesFromAdapterStackRange(adapter.stack, adapter.lowestStackSize, adapter.uninitializedTypes),
info.startIndex,
info.getSize()
);
}
protected SortedMap<Integer, Type> localMapFromAdapterLocalMap(
SortedMap<Integer, Object> map, Map<Object, Object> uninitializedTypes) {
SortedMap<Integer, Type> ret = new TreeMap<>();
map.forEach((k, v) -> ret.put(k, typeFromAdapterStackItem(v, uninitializedTypes)));
return ret;
}
protected List<Type> typesFromAdapterStackRange(
List<Object> stack, int start, Map<Object, Object> uninitializedTypes) {
List<Type> ret = new ArrayList<>();
for (int i = start; i < stack.size(); i++) {
Object item = stack.get(i);
ret.add(typeFromAdapterStackItem(item, uninitializedTypes));
// Jump an extra spot for longs and doubles
if (item == Opcodes.LONG || item == Opcodes.DOUBLE) {
if (stack.get(++i) != Opcodes.TOP) throw new IllegalStateException("Expected top after long/double");
}
}
return ret;
}
protected Type typeFromAdapterStackItem(Object item, Map<Object, Object> uninitializedTypes) {
if (item == Opcodes.INTEGER) return Type.INT_TYPE;
else if (item == Opcodes.FLOAT) return Type.FLOAT_TYPE;
else if (item == Opcodes.LONG) return Type.LONG_TYPE;
else if (item == Opcodes.DOUBLE) return Type.DOUBLE_TYPE;
else if (item == Opcodes.NULL) return OBJECT_TYPE;
else if (item == Opcodes.UNINITIALIZED_THIS) return Type.getObjectType(owner);
else if (item instanceof Label) return Type.getObjectType((String) uninitializedTypes.get(item));
else if (item instanceof String) return Type.getObjectType((String) item);
else throw new IllegalStateException("Unrecognized stack item: " + item);
}
}
protected static class StackAndLocalTrackingAdapter extends AnalyzerAdapter {
public int lowestStackSize;
public final SortedMap<Integer, Object> localsRead = new TreeMap<>();
public final SortedMap<Integer, Object> localsWritten = new TreeMap<>();
protected StackAndLocalTrackingAdapter(Splitter splitter) {
super(splitter.api, splitter.owner, splitter.method.access, splitter.method.name, splitter.method.desc, null);
stack = new SizeChangeNotifyList<Object>() {
@Override
protected void onSizeChanged() { lowestStackSize = Math.min(lowestStackSize, size()); }
};
}
@Override
public void visitVarInsn(int opcode, int var) {
switch (opcode) {
case Opcodes.ILOAD:
case Opcodes.LLOAD:
case Opcodes.FLOAD:
case Opcodes.DLOAD:
case Opcodes.ALOAD:
localsRead.put(var, locals.get(var));
break;
case Opcodes.ISTORE:
case Opcodes.FSTORE:
case Opcodes.ASTORE:
localsWritten.put(var, stack.get(stack.size() - 1));
break;
case Opcodes.LSTORE:
case Opcodes.DSTORE:
localsWritten.put(var, stack.get(stack.size() - 2));
break;
}
super.visitVarInsn(opcode, var);
}
@Override
public void visitIincInsn(int var, int increment) {
localsRead.put(var, Type.INT_TYPE);
localsWritten.put(var, Type.INT_TYPE);
super.visitIincInsn(var, increment);
}
}
protected static class SizeChangeNotifyList<T> extends AbstractList<T> {
protected final ArrayList<T> list = new ArrayList<>();
protected void onSizeChanged() { }
@Override
public T get(int index) { return list.get(index); }
@Override
public int size() { return list.size(); }
@Override
public T set(int index, T element) { return list.set(index, element); }
@Override
public void add(int index, T element) {
list.add(index, element);
onSizeChanged();
}
@Override
public T remove(int index) {
T ret = list.remove(index);
onSizeChanged();
return ret;
}
}
protected static class InsnTraverseInfo {
public int startIndex;
// Can only shrink, never increase in size
public int endIndex;
public int getSize() { return endIndex - startIndex + 1; }
}
}
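Because Splitter is Iterable, candidate split points can be inspected directly. A sketch with arbitrary size bounds and a hypothetical owner name:

import asmble.compile.jvm.msplit.Splitter
import org.objectweb.asm.Opcodes
import org.objectweb.asm.tree.MethodNode

fun dumpSplitPoints(ownerInternalName: String, method: MethodNode) {
    // ask for split points of at least 20 and at most 500 instructions
    for (point in Splitter(Opcodes.ASM6, ownerInternalName, method, 20, 500)) {
        println(
            "start=${point.start} length=${point.length} " +
                "localsRead=${point.localsRead.keys} stackIn=${point.neededFromStackAtStart.size}"
        )
    }
}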


@@ -0,0 +1,84 @@
package asmble.compile.jvm.msplit;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.*;
class Util {
private Util() { }
static final Type OBJECT_TYPE = Type.getType(Object.class);
static AbstractInsnNode zeroVal(Type type) {
if (type == Type.INT_TYPE) return new InsnNode(Opcodes.ICONST_0);
else if (type == Type.LONG_TYPE) return new InsnNode(Opcodes.LCONST_0);
else if (type == Type.FLOAT_TYPE) return new InsnNode(Opcodes.FCONST_0);
else if (type == Type.DOUBLE_TYPE) return new InsnNode(Opcodes.DCONST_0);
else return new InsnNode(Opcodes.ACONST_NULL);
}
static boolean isStoreOp(int opcode) {
return opcode == Opcodes.ISTORE || opcode == Opcodes.LSTORE || opcode == Opcodes.FSTORE ||
opcode == Opcodes.DSTORE || opcode == Opcodes.ASTORE;
}
static int storeOpFromType(Type type) {
if (type == Type.INT_TYPE) return Opcodes.ISTORE;
else if (type == Type.LONG_TYPE) return Opcodes.LSTORE;
else if (type == Type.FLOAT_TYPE) return Opcodes.FSTORE;
else if (type == Type.DOUBLE_TYPE) return Opcodes.DSTORE;
else return Opcodes.ASTORE;
}
static int loadOpFromType(Type type) {
if (type == Type.INT_TYPE) return Opcodes.ILOAD;
else if (type == Type.LONG_TYPE) return Opcodes.LLOAD;
else if (type == Type.FLOAT_TYPE) return Opcodes.FLOAD;
else if (type == Type.DOUBLE_TYPE) return Opcodes.DLOAD;
else return Opcodes.ALOAD;
}
static Type boxedTypeIfNecessary(Type type) {
if (type == Type.INT_TYPE) return Type.getType(Integer.class);
else if (type == Type.LONG_TYPE) return Type.getType(Long.class);
else if (type == Type.FLOAT_TYPE) return Type.getType(Float.class);
else if (type == Type.DOUBLE_TYPE) return Type.getType(Double.class);
else return type;
}
static void boxStackIfNecessary(Type type, MethodNode method) {
if (type == Type.INT_TYPE) boxCall(Integer.class, type).accept(method);
else if (type == Type.FLOAT_TYPE) boxCall(Float.class, type).accept(method);
else if (type == Type.LONG_TYPE) boxCall(Long.class, type).accept(method);
else if (type == Type.DOUBLE_TYPE) boxCall(Double.class, type).accept(method);
}
static void unboxStackIfNecessary(Type type, MethodNode method) {
if (type == Type.INT_TYPE) method.visitMethodInsn(Opcodes.INVOKEVIRTUAL,
"java/lang/Integer", "intValue", Type.getMethodDescriptor(Type.INT_TYPE), false);
else if (type == Type.FLOAT_TYPE) method.visitMethodInsn(Opcodes.INVOKEVIRTUAL,
"java/lang/Float", "floatValue", Type.getMethodDescriptor(Type.FLOAT_TYPE), false);
else if (type == Type.LONG_TYPE) method.visitMethodInsn(Opcodes.INVOKEVIRTUAL,
"java/lang/Long", "longValue", Type.getMethodDescriptor(Type.LONG_TYPE), false);
else if (type == Type.DOUBLE_TYPE) method.visitMethodInsn(Opcodes.INVOKEVIRTUAL,
"java/lang/Double", "doubleValue", Type.getMethodDescriptor(Type.DOUBLE_TYPE), false);
}
static AbstractInsnNode intConst(int v) {
switch (v) {
case -1: return new InsnNode(Opcodes.ICONST_M1);
case 0: return new InsnNode(Opcodes.ICONST_0);
case 1: return new InsnNode(Opcodes.ICONST_1);
case 2: return new InsnNode(Opcodes.ICONST_2);
case 3: return new InsnNode(Opcodes.ICONST_3);
case 4: return new InsnNode(Opcodes.ICONST_4);
case 5: return new InsnNode(Opcodes.ICONST_5);
default: return new LdcInsnNode(v);
}
}
static MethodInsnNode boxCall(Class<?> boxType, Type primType) {
return new MethodInsnNode(Opcodes.INVOKESTATIC, Type.getInternalName(boxType),
"valueOf", Type.getMethodDescriptor(Type.getType(boxType), primType), false);
}
}


@@ -3,7 +3,19 @@ package asmble.ast
import java.util.*
import kotlin.reflect.KClass
/**
* All WebAssembly AST nodes as static inner classes.
*/
sealed class Node {
/**
* Wasm module definition.
*
* The unit of WebAssembly code is the module. A module collects definitions
* for types, functions, tables, memories, and globals. In addition, it can
* declare imports and exports and provide initialization logic in the form
* of data and element segments or a start function.
*/
data class Module(
val types: List<Type.Func> = emptyList(),
val imports: List<Import> = emptyList(),
@@ -15,6 +27,7 @@ sealed class Node {
val elems: List<Elem> = emptyList(),
val funcs: List<Func> = emptyList(),
val data: List<Data> = emptyList(),
val names: NameSection? = null,
val customSections: List<CustomSection> = emptyList()
) : Node()
@@ -148,6 +161,12 @@ }
}
}
data class NameSection(
val moduleName: String?,
val funcNames: Map<Int, String>,
val localNames: Map<Int, Map<Int, String>>
) : Node()
sealed class Instr : Node() {
fun op() = InstrOp.classToOpMap[this::class] ?: throw Exception("No op found for ${this::class}")
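A small sketch of the new name support from the Kotlin side; the module name, function names, and local names below are invented purely to show the shape of NameSection and the new Module.names field.

import asmble.ast.Node

val names = Node.NameSection(
    moduleName = "demo",
    funcNames = mapOf(0 to "add"),                     // function index -> name
    localNames = mapOf(0 to mapOf(0 to "x", 1 to "y")) // function index -> (local index -> name)
)
val moduleWithNames = Node.Module(names = names)       // every other section keeps its default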
@@ -165,12 +184,15 @@ interface Const<out T : Number> : Args { val value: T }
interface Const<out T : Number> : Args { val value: T }
}
// Control flow
// Control instructions [https://www.w3.org/TR/2018/WD-wasm-core-1-20180215/#control-instructions]
object Unreachable : Instr(), Args.None
object Nop : Instr(), Args.None
data class Block(override val type: Type.Value?) : Instr(), Args.Type
data class Loop(override val type: Type.Value?) : Instr(), Args.Type
data class If(override val type: Type.Value?) : Instr(), Args.Type
object Else : Instr(), Args.None
object End : Instr(), Args.None
data class Br(override val relativeDepth: Int) : Instr(), Args.RelativeDepth
@@ -181,25 +203,27 @@ ) : Instr(), Args.Table
) : Instr(), Args.Table
object Return : Instr()
// Call operators
data class Call(override val index: Int) : Instr(), Args.Index
data class CallIndirect(
override val index: Int,
override val reserved: Boolean
) : Instr(), Args.ReservedIndex
// Parametric operators
// Parametric instructions [https://www.w3.org/TR/2018/WD-wasm-core-1-20180215/#parametric-instructions]
object Drop : Instr(), Args.None
object Select : Instr(), Args.None
// Variable access
// Variable instructions [https://www.w3.org/TR/2018/WD-wasm-core-1-20180215/#variable-instructions]
data class GetLocal(override val index: Int) : Instr(), Args.Index
data class SetLocal(override val index: Int) : Instr(), Args.Index
data class TeeLocal(override val index: Int) : Instr(), Args.Index
data class GetGlobal(override val index: Int) : Instr(), Args.Index
data class SetGlobal(override val index: Int) : Instr(), Args.Index
// Memory operators
// Memory instructions [https://www.w3.org/TR/2018/WD-wasm-core-1-20180215/#memory-instructions]
data class I32Load(override val align: Int, override val offset: Long) : Instr(), Args.AlignOffset
data class I64Load(override val align: Int, override val offset: Long) : Instr(), Args.AlignOffset
data class F32Load(override val align: Int, override val offset: Long) : Instr(), Args.AlignOffset
@@ -223,10 +247,12 @@ data class I64Store8(override val align: Int, override val offset: Long) : Instr(), Args.AlignOffset
data class I64Store8(override val align: Int, override val offset: Long) : Instr(), Args.AlignOffset
data class I64Store16(override val align: Int, override val offset: Long) : Instr(), Args.AlignOffset
data class I64Store32(override val align: Int, override val offset: Long) : Instr(), Args.AlignOffset
data class CurrentMemory(override val reserved: Boolean) : Instr(), Args.Reserved
data class GrowMemory(override val reserved: Boolean) : Instr(), Args.Reserved
data class MemorySize(override val reserved: Boolean) : Instr(), Args.Reserved
data class MemoryGrow(override val reserved: Boolean) : Instr(), Args.Reserved
// Constants
// Numeric instructions [https://www.w3.org/TR/2018/WD-wasm-core-1-20180215/#numeric-instructions]
// Constants operators
data class I32Const(override val value: Int) : Instr(), Args.Const<Int>
data class I64Const(override val value: Long) : Instr(), Args.Const<Long>
data class F32Const(override val value: Float) : Instr(), Args.Const<Float>
@@ -511,8 +537,8 @@ opMapEntry("i64.store8", 0x3c, ::MemOpAlignOffsetArg, Instr::I64Store8, Instr.I64Store8::class)
opMapEntry("i64.store8", 0x3c, ::MemOpAlignOffsetArg, Instr::I64Store8, Instr.I64Store8::class)
opMapEntry("i64.store16", 0x3d, ::MemOpAlignOffsetArg, Instr::I64Store16, Instr.I64Store16::class)
opMapEntry("i64.store32", 0x3e, ::MemOpAlignOffsetArg, Instr::I64Store32, Instr.I64Store32::class)
opMapEntry("current_memory", 0x3f, ::MemOpReservedArg, Instr::CurrentMemory, Instr.CurrentMemory::class)
opMapEntry("grow_memory", 0x40, ::MemOpReservedArg, Instr::GrowMemory, Instr.GrowMemory::class)
opMapEntry("memory.size", 0x3f, ::MemOpReservedArg, Instr::MemorySize, Instr.MemorySize::class)
opMapEntry("memory.grow", 0x40, ::MemOpReservedArg, Instr::MemoryGrow, Instr.MemoryGrow::class)
opMapEntry("i32.const", 0x41, ::ConstOpIntArg, Instr::I32Const, Instr.I32Const::class)
opMapEntry("i64.const", 0x42, ::ConstOpLongArg, Instr::I64Const, Instr.I64Const::class)


@@ -2,10 +2,16 @@ package asmble.ast
import asmble.io.SExprToStr
/**
* Ast representation of wasm S-expressions (wast format).
* see [[https://webassembly.github.io/spec/core/text/index.html]]
*/
sealed class SExpr {
data class Multi(val vals: List<SExpr> = emptyList()) : SExpr() {
override fun toString() = SExprToStr.Compact.fromSExpr(this)
}
data class Symbol(
val contents: String = "",
val quoted: Boolean = false,
@@ -15,4 +21,5 @@
// This is basically the same as the deprecated java.lang.String#getBytes
fun rawContentCharsToBytes() = contents.toCharArray().map(Char::toByte)
}
}
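For example, the wast fragment (module (memory 1)) would correspond to a tree like the one below, assuming the remaining Symbol parameters keep their defaults:

import asmble.ast.SExpr

val memoryExpr = SExpr.Multi(listOf(SExpr.Symbol("memory"), SExpr.Symbol("1")))
val moduleExpr = SExpr.Multi(listOf(SExpr.Symbol("module"), memoryExpr))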


@@ -1,6 +1,10 @@
package asmble.ast
/**
* Ast representation of wasm script.
*/
data class Script(val commands: List<Cmd>) {
sealed class Cmd {
data class Module(val module: Node.Module, val name: String?): Cmd()
data class Register(val string: String, val name: String?): Cmd()


@@ -0,0 +1,250 @@
package asmble.ast
// This is a utility for walking the stack. It can do validation or just walk naively.
data class Stack(
// If some of these values below are null, the pops/pushes may appear "unknown"
val mod: CachedModule? = null,
val func: Node.Func? = null,
// Null if not tracking the current stack and all pops succeed
val current: List<Node.Type.Value>? = null,
val insnApplies: List<InsnApply> = emptyList(),
val strict: Boolean = false,
val unreachableUntilNextEndCount: Int = 0
) {
fun next(v: Node.Instr, callFuncTypeOverride: Node.Type.Func? = null) = insnApply(v) {
// If we're unreachable, and not an end, we skip and move on
if (unreachableUntilNextEndCount > 0 && v !is Node.Instr.End) {
// If it's a block, we increase it because we'll see another end
return@insnApply if (v is Node.Instr.Args.Type) unreachable(unreachableUntilNextEndCount + 1) else nop()
}
when (v) {
is Node.Instr.Nop, is Node.Instr.Block, is Node.Instr.Loop -> nop()
is Node.Instr.If, is Node.Instr.BrIf -> popI32()
is Node.Instr.Return -> (func?.type?.ret?.let { pop(it) } ?: nop()) + unreachable(1)
is Node.Instr.Unreachable -> unreachable(1)
is Node.Instr.End, is Node.Instr.Else -> {
// Put back what was before the last block and add the block's type
// Go backwards to find the starting block
var currDepth = 0
val found = insnApplies.findLast {
when (it.insn) {
is Node.Instr.End -> { currDepth++; false }
is Node.Instr.Args.Type -> if (currDepth > 0) { currDepth--; false } else true
else -> false
}
}?.takeIf {
// When it's else, needs to be if
v !is Node.Instr.Else || it.insn is Node.Instr.If
}
val changes = when {
found != null && found.insn is Node.Instr.Args.Type &&
found.stackAtBeginning != null && this != null -> {
// Pop everything from before the block's start, then push if necessary...
// The If block includes an int at the beginning we must not include when subtracting
var preBlockStackSize = found.stackAtBeginning.size
if (found.insn is Node.Instr.If) preBlockStackSize--
val popped =
if (unreachableUntilNextEndCount > 1) nop()
else (0 until (size - preBlockStackSize)).flatMap { pop() }
// Only push if this is not an else
val pushed =
if (unreachableUntilNextEndCount > 1 || v is Node.Instr.Else) nop()
else (found.insn.type?.let { push(it) } ?: nop())
popped + pushed
}
strict -> error("Unable to find starting block for end")
else -> nop()
}
if (unreachableUntilNextEndCount > 0) changes + unreachable(unreachableUntilNextEndCount - 1)
else changes
}
is Node.Instr.Br -> unreachable(v.relativeDepth + 1)
is Node.Instr.BrTable -> popI32() + unreachable(1)
is Node.Instr.Call -> (callFuncTypeOverride ?: func(v.index)).let {
if (it == null) error("Call func type missing")
it.params.reversed().flatMap { pop(it) } + (it.ret?.let { push(it) } ?: nop())
}
is Node.Instr.CallIndirect -> (callFuncTypeOverride ?: mod?.mod?.types?.getOrNull(v.index)).let {
if (it == null) error("Call func type missing")
// We add one for the table index
popI32() + it.params.reversed().flatMap { pop(it) } + (it.ret?.let { push(it) } ?: nop())
}
is Node.Instr.Drop -> pop()
is Node.Instr.Select -> popI32() + pop().let { it + pop(it.first().type) + push(it.first().type) }
is Node.Instr.GetLocal -> push(local(v.index))
is Node.Instr.SetLocal -> pop(local(v.index))
is Node.Instr.TeeLocal -> local(v.index).let { pop(it) + push(it) }
is Node.Instr.GetGlobal -> push(global(v.index))
is Node.Instr.SetGlobal -> pop(global(v.index))
is Node.Instr.I32Load, is Node.Instr.I32Load8S, is Node.Instr.I32Load8U,
is Node.Instr.I32Load16U, is Node.Instr.I32Load16S -> popI32() + pushI32()
is Node.Instr.I64Load, is Node.Instr.I64Load8S, is Node.Instr.I64Load8U, is Node.Instr.I64Load16U,
is Node.Instr.I64Load16S, is Node.Instr.I64Load32S, is Node.Instr.I64Load32U -> popI32() + pushI64()
is Node.Instr.F32Load -> popI32() + pushF32()
is Node.Instr.F64Load -> popI32() + pushF64()
is Node.Instr.I32Store, is Node.Instr.I32Store8, is Node.Instr.I32Store16 -> popI32() + popI32()
is Node.Instr.I64Store, is Node.Instr.I64Store8,
is Node.Instr.I64Store16, is Node.Instr.I64Store32 -> popI64() + popI32()
is Node.Instr.F32Store -> popF32() + popI32()
is Node.Instr.F64Store -> popF64() + popI32()
is Node.Instr.MemorySize -> pushI32()
is Node.Instr.MemoryGrow -> popI32() + pushI32()
is Node.Instr.I32Const -> pushI32()
is Node.Instr.I64Const -> pushI64()
is Node.Instr.F32Const -> pushF32()
is Node.Instr.F64Const -> pushF64()
is Node.Instr.I32Add, is Node.Instr.I32Sub, is Node.Instr.I32Mul, is Node.Instr.I32DivS,
is Node.Instr.I32DivU, is Node.Instr.I32RemS, is Node.Instr.I32RemU, is Node.Instr.I32And,
is Node.Instr.I32Or, is Node.Instr.I32Xor, is Node.Instr.I32Shl, is Node.Instr.I32ShrS,
is Node.Instr.I32ShrU, is Node.Instr.I32Rotl, is Node.Instr.I32Rotr, is Node.Instr.I32Eq,
is Node.Instr.I32Ne, is Node.Instr.I32LtS, is Node.Instr.I32LeS, is Node.Instr.I32LtU,
is Node.Instr.I32LeU, is Node.Instr.I32GtS, is Node.Instr.I32GeS, is Node.Instr.I32GtU,
is Node.Instr.I32GeU -> popI32() + popI32() + pushI32()
is Node.Instr.I32Clz, is Node.Instr.I32Ctz, is Node.Instr.I32Popcnt,
is Node.Instr.I32Eqz -> popI32() + pushI32()
is Node.Instr.I64Add, is Node.Instr.I64Sub, is Node.Instr.I64Mul, is Node.Instr.I64DivS,
is Node.Instr.I64DivU, is Node.Instr.I64RemS, is Node.Instr.I64RemU, is Node.Instr.I64And,
is Node.Instr.I64Or, is Node.Instr.I64Xor, is Node.Instr.I64Shl, is Node.Instr.I64ShrS,
is Node.Instr.I64ShrU, is Node.Instr.I64Rotl, is Node.Instr.I64Rotr -> popI64() + popI64() + pushI64()
is Node.Instr.I64Eq, is Node.Instr.I64Ne, is Node.Instr.I64LtS, is Node.Instr.I64LeS,
is Node.Instr.I64LtU, is Node.Instr.I64LeU, is Node.Instr.I64GtS,
is Node.Instr.I64GeS, is Node.Instr.I64GtU, is Node.Instr.I64GeU -> popI64() + popI64() + pushI32()
is Node.Instr.I64Clz, is Node.Instr.I64Ctz, is Node.Instr.I64Popcnt -> popI64() + pushI64()
is Node.Instr.I64Eqz -> popI64() + pushI32()
is Node.Instr.F32Add, is Node.Instr.F32Sub, is Node.Instr.F32Mul, is Node.Instr.F32Div,
is Node.Instr.F32Min, is Node.Instr.F32Max, is Node.Instr.F32CopySign -> popF32() + popF32() + pushF32()
is Node.Instr.F32Eq, is Node.Instr.F32Ne, is Node.Instr.F32Lt, is Node.Instr.F32Le,
is Node.Instr.F32Gt, is Node.Instr.F32Ge -> popF32() + popF32() + pushI32()
is Node.Instr.F32Abs, is Node.Instr.F32Neg, is Node.Instr.F32Ceil, is Node.Instr.F32Floor,
is Node.Instr.F32Trunc, is Node.Instr.F32Nearest, is Node.Instr.F32Sqrt -> popF32() + pushF32()
is Node.Instr.F64Add, is Node.Instr.F64Sub, is Node.Instr.F64Mul, is Node.Instr.F64Div,
is Node.Instr.F64Min, is Node.Instr.F64Max, is Node.Instr.F64CopySign -> popF64() + popF64() + pushF64()
is Node.Instr.F64Eq, is Node.Instr.F64Ne, is Node.Instr.F64Lt, is Node.Instr.F64Le,
is Node.Instr.F64Gt, is Node.Instr.F64Ge -> popF64() + popF64() + pushI32()
is Node.Instr.F64Abs, is Node.Instr.F64Neg, is Node.Instr.F64Ceil, is Node.Instr.F64Floor,
is Node.Instr.F64Trunc, is Node.Instr.F64Nearest, is Node.Instr.F64Sqrt -> popF64() + pushF64()
is Node.Instr.I32WrapI64 -> popI64() + pushI32()
is Node.Instr.I32TruncSF32, is Node.Instr.I32TruncUF32,
is Node.Instr.I32ReinterpretF32 -> popF32() + pushI32()
is Node.Instr.I32TruncSF64, is Node.Instr.I32TruncUF64 -> popF64() + pushI32()
is Node.Instr.I64ExtendSI32, is Node.Instr.I64ExtendUI32 -> popI32() + pushI64()
is Node.Instr.I64TruncSF32, is Node.Instr.I64TruncUF32 -> popF32() + pushI64()
is Node.Instr.I64TruncSF64, is Node.Instr.I64TruncUF64,
is Node.Instr.I64ReinterpretF64 -> popF64() + pushI64()
is Node.Instr.F32ConvertSI32, is Node.Instr.F32ConvertUI32,
is Node.Instr.F32ReinterpretI32 -> popI32() + pushF32()
is Node.Instr.F32ConvertSI64, is Node.Instr.F32ConvertUI64 -> popI64() + pushF32()
is Node.Instr.F32DemoteF64 -> popF64() + pushF32()
is Node.Instr.F64ConvertSI32, is Node.Instr.F64ConvertUI32 -> popI32() + pushF64()
is Node.Instr.F64ConvertSI64, is Node.Instr.F64ConvertUI64,
is Node.Instr.F64ReinterpretI64 -> popI64() + pushF64()
is Node.Instr.F64PromoteF32 -> popF32() + pushF64()
}
}
protected fun insnApply(v: Node.Instr, fn: MutableList<Node.Type.Value>?.() -> List<InsnApplyResponse>): Stack {
val mutStack = current?.toMutableList()
val applyResp = mutStack.fn()
val newUnreachable = (applyResp.find { it is Unreachable } as? Unreachable)?.untilEndCount
return copy(
current = mutStack,
insnApplies = insnApplies + InsnApply(
insn = v,
stackAtBeginning = current,
stackChanges = applyResp.mapNotNull { it as? StackChange },
unreachableUntilEndCount = newUnreachable ?: unreachableUntilNextEndCount
),
unreachableUntilNextEndCount = newUnreachable ?: unreachableUntilNextEndCount
)
}
protected fun unreachable(untilEndCount: Int) = listOf(Unreachable(untilEndCount))
protected fun local(index: Int) = func?.let {
it.type.params.getOrNull(index) ?: it.locals.getOrNull(index - it.type.params.size)
}
protected fun global(index: Int) = mod?.let {
it.importGlobals.getOrNull(index)?.type?.contentType ?:
it.mod.globals.getOrNull(index - it.importGlobals.size)?.type?.contentType
}
protected fun func(index: Int) = mod?.let {
it.importFuncs.getOrNull(index)?.typeIndex?.let { i -> it.mod.types.getOrNull(i) } ?:
it.mod.funcs.getOrNull(index - it.importFuncs.size)?.type
}
protected fun nop() = emptyList<StackChange>()
protected fun MutableList<Node.Type.Value>?.popType(expecting: Node.Type.Value? = null) =
this?.takeIf {
it.isNotEmpty().also {
require(!strict || it) { "Expected $expecting got empty" }
}
}?.let {
removeAt(size - 1).takeIf { actual -> (expecting == null || actual == expecting).also {
require(!strict || it) { "Expected $expecting got $actual" }
} }
} ?: expecting
protected fun MutableList<Node.Type.Value>?.pop(expecting: Node.Type.Value? = null) =
listOf(StackChange(popType(expecting), true))
protected fun MutableList<Node.Type.Value>?.popI32() = pop(Node.Type.Value.I32)
protected fun MutableList<Node.Type.Value>?.popI64() = pop(Node.Type.Value.I64)
protected fun MutableList<Node.Type.Value>?.popF32() = pop(Node.Type.Value.F32)
protected fun MutableList<Node.Type.Value>?.popF64() = pop(Node.Type.Value.F64)
protected fun MutableList<Node.Type.Value>?.push(type: Node.Type.Value? = null) =
listOf(StackChange(type, false)).also { if (this != null && type != null) add(type) }
protected fun MutableList<Node.Type.Value>?.pushI32() = push(Node.Type.Value.I32)
protected fun MutableList<Node.Type.Value>?.pushI64() = push(Node.Type.Value.I64)
protected fun MutableList<Node.Type.Value>?.pushF32() = push(Node.Type.Value.F32)
protected fun MutableList<Node.Type.Value>?.pushF64() = push(Node.Type.Value.F64)
data class InsnApply(
val insn: Node.Instr,
val stackAtBeginning: List<Node.Type.Value>?,
val stackChanges: List<StackChange>,
val unreachableUntilEndCount: Int
)
protected interface InsnApplyResponse
data class StackChange(
val type: Node.Type.Value?,
val pop: Boolean
) : InsnApplyResponse
data class Unreachable(
val untilEndCount: Int
) : InsnApplyResponse
class CachedModule(val mod: Node.Module) {
val importFuncs by lazy { mod.imports.mapNotNull { it.kind as? Node.Import.Kind.Func } }
val importGlobals by lazy { mod.imports.mapNotNull { it.kind as? Node.Import.Kind.Global } }
}
companion object {
fun walkStrict(mod: Node.Module, func: Node.Func, afterInsn: ((Stack, Node.Instr) -> Unit)? = null) =
func.instructions.fold(Stack(
mod = CachedModule(mod),
func = func,
current = emptyList(),
strict = true
)) { stack, insn -> stack.next(insn).also { afterInsn?.invoke(it, insn) } }.also { stack ->
// We expect to be in an unreachable state at the end or have the single return value on the stack
if (stack.unreachableUntilNextEndCount == 0) {
val expectedStack = (func.type.ret?.let { listOf(it) } ?: emptyList())
require(expectedStack == stack.current) {
"Expected end to be $expectedStack, got ${stack.current}"
}
}
}
fun stackChanges(v: Node.Instr, callFuncType: Node.Type.Func? = null) =
Stack().next(v, callFuncType).insnApplies.last().stackChanges
fun stackChanges(mod: CachedModule, func: Node.Func, v: Node.Instr) =
Stack(mod, func).next(v).insnApplies.last().stackChanges
fun stackDiff(v: Node.Instr, callFuncType: Node.Type.Func? = null) =
stackChanges(v, callFuncType).sumBy { if (it.pop) -1 else 1 }
fun stackDiff(mod: CachedModule, func: Node.Func, v: Node.Instr) =
stackChanges(mod, func, v).sumBy { if (it.pop) -1 else 1 }
}
}
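A short, hypothetical usage sketch of the companion helpers defined above; only the module-free overloads are shown, since they need no Node.Module on hand.

import asmble.ast.Node
import asmble.ast.Stack

fun main() {
    // i32.const pushes a single value, so the net stack diff is +1
    println(Stack.stackDiff(Node.Instr.I32Const(7)))    // prints 1
    // stackChanges exposes the individual pushes/pops with their value types
    println(Stack.stackChanges(Node.Instr.I32Const(7))) // a single non-pop change of type i32
}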

View File

@ -1,9 +1,9 @@
package asmble.cli
import asmble.ast.Script
import asmble.compile.jvm.AsmToBinary
import asmble.compile.jvm.AstToAsm
import asmble.compile.jvm.ClsContext
import asmble.compile.jvm.withComputedFramesAndMaxs
import java.io.FileOutputStream
@Suppress("NAME_SHADOWING")
@ -51,7 +51,7 @@ open class Compile : Command<Compile.Args>() {
val inFormat =
if (args.inFormat != "<use file extension>") args.inFormat
else args.inFile.substringAfterLast('.', "<unknown>")
val script = Translate.inToAst(args.inFile, inFormat)
val script = Translate().also { it.logger = logger }.inToAst(args.inFile, inFormat)
val mod = (script.commands.firstOrNull() as? Script.Cmd.Module) ?:
error("Only a single sexpr for (module) allowed")
val outStream = when (args.outFile) {
@ -69,7 +69,7 @@ open class Compile : Command<Compile.Args>() {
includeBinary = args.includeBinary
)
AstToAsm.fromModule(ctx)
outStream.write(ctx.cls.withComputedFramesAndMaxs())
outStream.write(AsmToBinary.fromClassNode(ctx.cls))
}
}

View File

@ -3,6 +3,9 @@ package asmble.cli
import asmble.compile.jvm.javaIdent
import asmble.run.jvm.Module
/**
* This class provides the 'invoke' functionality for WASM code.
*/
open class Invoke : ScriptCommand<Invoke.Args>() {
override val name = "invoke"
@ -34,6 +37,7 @@ open class Invoke : ScriptCommand<Invoke.Args>() {
).also { bld.done() }
override fun run(args: Args) {
// Compiles wasm to bytecode, does registrations, and so on.
val ctx = prepareContext(args.scriptArgs)
// Instantiate the module
val module =
@ -41,11 +45,11 @@ open class Invoke : ScriptCommand<Invoke.Args>() {
else ctx.registrations[args.module] as? Module.Instance ?:
error("Unable to find module registered as ${args.module}")
// Just make sure the module is instantiated here...
module.instance(ctx)
val instance = module.instance(ctx)
// If an export is provided, call it
if (args.export != "<start-func>") args.export.javaIdent.let { javaName ->
val method = module.cls.declaredMethods.find { it.name == javaName } ?:
error("Unable to find export '${args.export}'")
// Finds the Java method (wasm fn) in the class (wasm module) by the export name
val method = module.cls.declaredMethods.find { it.name == javaName } ?: error("Unable to find export '${args.export}'")
// Map args to params
require(method.parameterTypes.size == args.args.size) {
"Given arg count of ${args.args.size} is invalid for $method"
@ -59,11 +63,20 @@ open class Invoke : ScriptCommand<Invoke.Args>() {
else -> error("Unrecognized type for param ${index + 1}: $paramType")
}
}
val result = method.invoke(module.instance(ctx), *params.toTypedArray())
val result = method.invoke(instance, *params.toTypedArray())
if (args.resultToStdout && method.returnType != Void.TYPE) println(result)
}
}
/**
* Arguments for 'invoke' command.
*
* @param scriptArgs Common arguments for 'invoke' and 'run' ScriptCommands.
* @param module The module name to run. If it's a JVM class, it must have a no-arg constructor
* @param export The specific export function to invoke
* @param args Parameters for the export, if an export is given
* @param resultToStdout If true, the result is printed to stdout
*/
data class Args(
val scriptArgs: ScriptCommand.ScriptArgs,
val module: String,
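A hedged sketch of building the arguments documented above; the field types past those shown in this hunk (export, args, resultToStdout) are inferred from how run() uses them and are not confirmed by the diff, and scriptArgs is assumed to be built elsewhere.

// Hypothetical construction mirroring the KDoc above; types of the truncated
// fields are inferred from run(), not confirmed by the diff.
val invokeArgs = Invoke.Args(
    scriptArgs = scriptArgs,   // assumed ScriptCommand.ScriptArgs built elsewhere
    module = "myModule",       // hypothetical registered module name
    export = "add",            // the export function to call
    args = listOf("1", "2"),   // string args mapped onto the method's parameters
    resultToStdout = true      // print the result if the method returns a value
)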

View File

@ -1,7 +1,7 @@
package asmble.cli
import asmble.compile.jvm.AsmToBinary
import asmble.compile.jvm.Linker
import asmble.compile.jvm.withComputedFramesAndMaxs
import java.io.FileOutputStream
open class Link : Command<Link.Args>() {
@ -52,7 +52,7 @@ open class Link : Command<Link.Args>() {
defaultMaxMemPages = args.defaultMaxMem
)
Linker.link(ctx)
outStream.write(ctx.cls.withComputedFramesAndMaxs())
outStream.write(AsmToBinary.fromClassNode(ctx.cls))
}
}

View File

@ -5,6 +5,9 @@ import kotlin.system.exitProcess
val commands = listOf(Compile, Help, Invoke, Link, Run, Translate)
/**
* Entry point of command line interface.
*/
fun main(args: Array<String>) {
if (args.isEmpty()) return println(
"""
@ -28,6 +31,7 @@ fun main(args: Array<String>) {
val globals = Main.globalArgs(argBuild)
logger = Logger.Print(globals.logLevel)
command.logger = logger
logger.info { "Running the command=${command.name} with args=${argBuild.args}" }
command.runWithArgs(argBuild)
} catch (e: Exception) {
logger.error { "Error ${command?.let { "in command '${it.name}'" } ?: ""}: ${e.message}" }

View File

@ -1,10 +1,10 @@
package asmble.cli
import asmble.ast.Script
import asmble.compile.jvm.javaIdent
import asmble.run.jvm.Module
import asmble.run.jvm.ScriptContext
import asmble.compile.jvm.*
import asmble.run.jvm.*
import java.io.File
import java.io.PrintWriter
import java.util.*
abstract class ScriptCommand<T> : Command<T>() {
@ -41,48 +41,97 @@ abstract class ScriptCommand<T> : Command<T>() {
desc = "The maximum number of memory pages when a module doesn't say.",
default = "5",
lowPriority = true
).toInt()
).toInt(),
enableLogger = bld.arg(
name = "enableLogger",
opt = "enableLogger",
desc = "Enables the special module the could be used for logging",
default = "false",
lowPriority = true
).toBoolean()
)
fun prepareContext(args: ScriptArgs): ScriptContext {
var ctx = ScriptContext(
var context = ScriptContext(
packageName = "asmble.temp" + UUID.randomUUID().toString().replace("-", ""),
defaultMaxMemPages = args.defaultMaxMemPages
defaultMaxMemPages = args.defaultMaxMemPages,
memoryBuilder = args.memoryBuilder
)
// Compile everything
ctx = args.inFiles.foldIndexed(ctx) { index, ctx, inFile ->
context = args.inFiles.foldIndexed(context) { index, ctx, inFile ->
try {
when (inFile.substringAfterLast('.')) {
// if input file is class file
"class" -> ctx.classLoader.addClass(File(inFile).readBytes()).let { ctx }
else -> Translate.inToAst(inFile, inFile.substringAfterLast('.')).let { inAst ->
val (mod, name) = (inAst.commands.singleOrNull() as? Script.Cmd.Module) ?:
// if input file is wasm file
else -> {
val translateCmd = Translate
translateCmd.logger = this.logger
translateCmd.inToAst(inFile, inFile.substringAfterLast('.')).let { inAst ->
val (mod, name) = (inAst.commands.singleOrNull() as? Script.Cmd.Module) ?:
error("Input file must only contain a single module")
val className = name?.javaIdent?.capitalize() ?:
val className = name?.javaIdent?.capitalize() ?:
"Temp" + UUID.randomUUID().toString().replace("-", "")
ctx.withCompiledModule(mod, className, name).let { ctx ->
if (name == null && index != args.inFiles.size - 1)
logger.warn { "File '$inFile' not last and has no name so will be unused" }
if (name == null || args.disableAutoRegister) ctx
else ctx.runCommand(Script.Cmd.Register(name, null))
ctx.withCompiledModule(mod, className, name).let { ctx ->
if (name == null && index != args.inFiles.size - 1)
logger.warn { "File '$inFile' not last and has no name so will be unused" }
if (name == null || args.disableAutoRegister) ctx
else ctx.runCommand(Script.Cmd.Register(name, null))
}
}
}
}
} catch (e: Exception) { throw Exception("Failed loading $inFile - ${e.message}", e) }
} catch (e: Exception) {
throw Exception("Failed loading $inFile - ${e.message}", e)
}
}
// Do registrations
ctx = args.registrations.fold(ctx) { ctx, (moduleName, className) ->
context = args.registrations.fold(context) { ctx, (moduleName, className) ->
ctx.withModuleRegistered(moduleName,
Module.Native(Class.forName(className, true, ctx.classLoader).newInstance()))
Module.Native(Class.forName(className, true, ctx.classLoader).newInstance()))
}
if (args.specTestRegister) ctx = ctx.withHarnessRegistered()
return ctx
if (args.specTestRegister) context = context.withHarnessRegistered()
if (args.enableLogger) {
// add logger Wasm module for logging
context =
context.withModuleRegistered(
"logger",
Module.Native(LoggerModule(PrintWriter(System.out)))
)
}
// add env Wasm module for gas metering
context =
context.withModuleRegistered(
"env",
// TODO: currently we are using almost infinite gas limit
Module.Native(EnvModule(Long.MAX_VALUE))
)
return context
}
/**
* Common arguments for 'invoke' and 'run' ScriptCommands.
*
* @param inFiles Files to add to classpath. Can be wasm, wast, or class file
* @param registrations Register class name to a module name
* @param disableAutoRegister If set, this will not auto-register modules with names
* @param specTestRegister If true, registers the spec test harness as 'spectest'
* @param defaultMaxMemPages The maximum number of memory pages when a module doesn't say
* @param enableLogger If set, the special logger module will be registered.
* @param memoryBuilder The builder used to create the memory buffer.
*/
data class ScriptArgs(
val inFiles: List<String>,
val registrations: List<Pair<String, String>>,
val disableAutoRegister: Boolean,
val specTestRegister: Boolean,
val defaultMaxMemPages: Int
val defaultMaxMemPages: Int,
val enableLogger: Boolean,
val memoryBuilder: MemoryBufferBuilder? = null
)
}
}
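For context, a minimal sketch of filling in ScriptArgs with the defaults documented above; memoryBuilder keeps its null default and the file name is a placeholder.

// Hypothetical values mirroring the documented defaults above.
val scriptArgs = ScriptCommand.ScriptArgs(
    inFiles = listOf("module.wasm"),
    registrations = emptyList(),
    disableAutoRegister = false,
    specTestRegister = false,
    defaultMaxMemPages = 5,
    enableLogger = true        // prepareContext will register the "logger" module
)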

View File

@ -52,30 +52,16 @@ open class Translate : Command<Translate.Args>() {
if (args.outFormat != "<use file extension or wast for stdout>") args.outFormat
else if (args.outFile == "--") "wast"
else args.outFile.substringAfterLast('.', "<unknown>")
val outStream =
if (args.outFile == "--") System.out
else FileOutputStream(args.outFile)
outStream.use { outStream ->
when (outFormat) {
"wast" -> {
val sexprToStr = if (args.compact) SExprToStr.Compact else SExprToStr
val sexprs = AstToSExpr.fromScript(script)
outStream.write(sexprToStr.fromSExpr(*sexprs.toTypedArray()).toByteArray())
}
"wasm" -> {
val mod = (script.commands.firstOrNull() as? Script.Cmd.Module)?.module ?:
error("Output to WASM requires input be just a single module")
AstToBinary.fromModule(ByteWriter.OutputStream(outStream), mod)
}
else -> error("Unknown out format '$outFormat'")
}
}
astToOut(args.outFile, outFormat, args.compact, script)
}
fun inToAst(inFile: String, inFormat: String): Script {
val inBytes =
if (inFile == "--") System.`in`.use { it.readBytes() }
else File(inFile).let { f -> FileInputStream(f).use { it.readBytes(f.length().toIntExact()) } }
else File(inFile).let { f ->
// The input file might not fit into memory
FileInputStream(f).use { it.readBytes(f.length().toIntExact()) }
}
return when (inFormat) {
"wast" -> StrToSExpr.parse(inBytes.toString(Charsets.UTF_8)).let { res ->
when (res) {
@ -84,12 +70,35 @@ open class Translate : Command<Translate.Args>() {
}
}
"wasm" ->
Script(listOf(Script.Cmd.Module(BinaryToAst.toModule(
ByteReader.InputStream(inBytes.inputStream())), null)))
BinaryToAst(logger = this.logger).toModule(
ByteReader.InputStream(inBytes.inputStream())).let { module ->
Script(listOf(Script.Cmd.Module(module, module.names?.moduleName)))
}
else -> error("Unknown in format '$inFormat'")
}
}
fun astToOut(outFile: String, outFormat: String, compact: Boolean, script: Script) {
val outStream =
if (outFile == "--") System.out
else FileOutputStream(outFile)
outStream.use { outStream ->
when (outFormat) {
"wast" -> {
val sexprToStr = if (compact) SExprToStr.Compact else SExprToStr
val sexprs = AstToSExpr.fromScript(script)
outStream.write(sexprToStr.fromSExpr(*sexprs.toTypedArray()).toByteArray())
}
"wasm" -> {
val mod = (script.commands.firstOrNull() as? Script.Cmd.Module)?.module ?:
error("Output to WASM requires input be just a single module")
AstToBinary.fromModule(ByteWriter.OutputStream(outStream), mod)
}
else -> error("Unknown out format '$outFormat'")
}
}
}
data class Args(
val inFile: String,
val inFormat: String,
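A hedged sketch of driving the split-out helpers directly; `logger` is assumed to be an already-configured asmble.util.Logger, as in the CLI entry point, and the file names are placeholders.

// Hypothetical round trip through the refactored helpers.
val translate = Translate().also { it.logger = logger }
val script = translate.inToAst("input.wasm", "wasm")
translate.astToOut("output.wast", "wast", compact = true, script = script)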

View File

@ -2,7 +2,6 @@ package asmble.compile.jvm
import asmble.ast.Node
import org.objectweb.asm.ClassReader
import org.objectweb.asm.ClassWriter
import org.objectweb.asm.Opcodes
import org.objectweb.asm.Type
import org.objectweb.asm.tree.*
@ -189,16 +188,6 @@ fun MethodNode.toAsmString(): String {
val Node.Type.Func.asmDesc: String get() =
(this.ret?.typeRef ?: Void::class.ref).asMethodRetDesc(*this.params.map { it.typeRef }.toTypedArray())
fun ClassNode.withComputedFramesAndMaxs(
cw: ClassWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES + ClassWriter.COMPUTE_MAXS)
): ByteArray {
// Note, compute maxs adds a bunch of NOPs for unreachable code.
// See $func12 of block.wast. I don't believe the extra time over the
// instructions to remove the NOPs is worth it.
this.accept(cw)
return cw.toByteArray()
}
fun ClassNode.toAsmString(): String {
val stringWriter = StringWriter()
this.accept(TraceClassVisitor(PrintWriter(stringWriter)))
@ -210,3 +199,7 @@ fun ByteArray.asClassNode(): ClassNode {
ClassReader(this).accept(newNode, 0)
return newNode
}
fun ByteArray.chunked(v: Int) = (0 until size step v).asSequence().map {
copyOfRange(it, (it + v).takeIf { it < size } ?: size)
}
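A quick illustration of the ByteArray.chunked helper added above, which is later used to split large data sections into LDC-sized pieces: 10 bytes in chunks of 4 yield sizes 4, 4 and 2.

import asmble.compile.jvm.chunked

fun main() {
    val bytes = ByteArray(10) { it.toByte() }
    // The last chunk is just the remainder, so the sizes are [4, 4, 2]
    println(bytes.chunked(4).map { it.size }.toList())
}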

View File

@ -0,0 +1,51 @@
package asmble.compile.jvm
import asmble.compile.jvm.msplit.SplitMethod
import org.objectweb.asm.ClassWriter
import org.objectweb.asm.MethodTooLargeException
import org.objectweb.asm.Opcodes
import org.objectweb.asm.tree.ClassNode
/**
* May mutate given class nodes on [fromClassNode] if [splitMethod] is present (the default). Uses the two-param
* [SplitMethod.split] call to try and split overly large methods.
*/
open class AsmToBinary(val splitMethod: SplitMethod? = SplitMethod(Opcodes.ASM6)) {
fun fromClassNode(
cn: ClassNode,
newClassWriter: () -> ClassWriter = { ClassWriter(ClassWriter.COMPUTE_FRAMES + ClassWriter.COMPUTE_MAXS) }
): ByteArray {
while (true) {
val cw = newClassWriter()
// Note, compute maxs adds a bunch of NOPs for unreachable code.
// See $func12 of block.wast. I don't believe the extra time over the
// instructions to remove the NOPs is worth it.
cn.accept(cw)
try {
return cw.toByteArray()
} catch (e: MethodTooLargeException) {
if (splitMethod == null) throw e
// Split the offending method by removing it and replacing it with the split ones
require(cn.name == e.className)
val tooLargeIndex = cn.methods.indexOfFirst { it.name == e.methodName && it.desc == e.descriptor }
require(tooLargeIndex >= 0)
val split = splitMethod.split(cn.name, cn.methods[tooLargeIndex])
split ?: throw IllegalStateException("Failed to split", e)
// Change the split off method's name if there's already one
val origName = split.splitOffMethod.name
var foundCount = 0
while (cn.methods.any { it.name == split.splitOffMethod.name }) {
split.splitOffMethod.name = origName + (++foundCount)
}
// Replace at the index
cn.methods.removeAt(tooLargeIndex)
cn.methods.add(tooLargeIndex, split.splitOffMethod)
cn.methods.add(tooLargeIndex, split.trimmedMethod)
}
}
}
companion object : AsmToBinary() {
val noSplit = AsmToBinary(null)
}
}
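A minimal sketch of the two ways the class above can be used: the default companion splits oversized methods via msplit, while noSplit simply rethrows MethodTooLargeException. `classNode` is assumed to be a fully built ClassNode (e.g. ctx.cls from Compile or Link).

import org.objectweb.asm.tree.ClassNode
import asmble.compile.jvm.AsmToBinary

fun toBytes(classNode: ClassNode): ByteArray =
    AsmToBinary.fromClassNode(classNode)            // splits too-large methods if needed

fun toBytesNoSplit(classNode: ClassNode): ByteArray =
    AsmToBinary.noSplit.fromClassNode(classNode)    // fails fast on MethodTooLargeException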

View File

@ -30,10 +30,11 @@ open class AstToAsm {
}
fun addFields(ctx: ClsContext) {
// Mem field if present
// Mem field if present; adds a `private final` field named `memory` to the class
if (ctx.hasMemory)
ctx.cls.fields.add(FieldNode(Opcodes.ACC_PRIVATE + Opcodes.ACC_FINAL, "memory",
ctx.mem.memType.asmDesc, null, null))
ctx.cls.fields.add(
FieldNode((Opcodes.ACC_PRIVATE + Opcodes.ACC_FINAL), "memory", ctx.mem.memType.asmDesc, null, null)
)
// Table field if present...
// Private final for now, but likely won't be final in future versions supporting
// mutable tables, may be not even a table but a list (and final)
@ -47,10 +48,13 @@ open class AstToAsm {
})
// Now all import globals as getter (and maybe setter) method handles
ctx.cls.fields.addAll(ctx.importGlobals.mapIndexed { index, import ->
if ((import.kind as Node.Import.Kind.Global).type.mutable) throw CompileErr.MutableGlobalImport(index)
FieldNode(Opcodes.ACC_PRIVATE + Opcodes.ACC_FINAL, ctx.importGlobalGetterFieldName(index),
val getter = FieldNode(Opcodes.ACC_PRIVATE + Opcodes.ACC_FINAL, ctx.importGlobalGetterFieldName(index),
MethodHandle::class.ref.asmDesc, null, null)
})
if (!(import.kind as Node.Import.Kind.Global).type.mutable) listOf(getter)
else listOf(getter, FieldNode(
Opcodes.ACC_PRIVATE + Opcodes.ACC_FINAL, ctx.importGlobalSetterFieldName(index),
MethodHandle::class.ref.asmDesc, null, null))
}.flatten())
// Now all non-import globals
ctx.cls.fields.addAll(ctx.mod.globals.mapIndexed { index, global ->
val access = Opcodes.ACC_PRIVATE + if (!global.type.mutable) Opcodes.ACC_FINAL else 0
@ -180,9 +184,11 @@ open class AstToAsm {
fun constructorImportTypes(ctx: ClsContext) =
ctx.importFuncs.map { MethodHandle::class.ref } +
// We know it's only getters
ctx.importGlobals.map { MethodHandle::class.ref } +
ctx.mod.imports.filter { it.kind is Node.Import.Kind.Table }.map { Array<MethodHandle>::class.ref }
ctx.importGlobals.flatMap {
// If it's mutable, it also comes with a setter
if ((it.kind as? Node.Import.Kind.Global)?.type?.mutable == false) listOf(MethodHandle::class.ref)
else listOf(MethodHandle::class.ref, MethodHandle::class.ref)
} + ctx.mod.imports.filter { it.kind is Node.Import.Kind.Table }.map { Array<MethodHandle>::class.ref }
fun toConstructorNode(ctx: ClsContext, func: Func) = mutableListOf<List<AnnotationNode>>().let { paramAnns ->
// If the first param is a mem class and imported, add annotation
@ -199,7 +205,15 @@ open class AstToAsm {
}
// All non-mem imports one after another
ctx.importFuncs.forEach { paramAnns.add(listOf(importAnnotation(ctx, it))) }
ctx.importGlobals.forEach { paramAnns.add(listOf(importAnnotation(ctx, it))) }
ctx.importGlobals.forEach {
paramAnns.add(listOf(importAnnotation(ctx, it)))
// There are two annotations here if it's mutable
if ((it.kind as? Node.Import.Kind.Global)?.type?.mutable == true)
paramAnns.add(listOf(importAnnotation(ctx, it).also {
it.values.add("globalSetter")
it.values.add(true)
}))
}
ctx.mod.imports.forEach {
if (it.kind is Node.Import.Kind.Table) paramAnns.add(listOf(importAnnotation(ctx, it)))
}
@ -240,14 +254,25 @@ open class AstToAsm {
}
fun setConstructorGlobalImports(ctx: ClsContext, func: Func, paramsBeforeImports: Int) =
ctx.importGlobals.indices.fold(func) { func, importIndex ->
ctx.importGlobals.foldIndexed(func to ctx.importFuncs.size + paramsBeforeImports) {
importIndex, (func, importParamOffset), import ->
// Always a getter handle
func.addInsns(
VarInsnNode(Opcodes.ALOAD, 0),
VarInsnNode(Opcodes.ALOAD, ctx.importFuncs.size + importIndex + paramsBeforeImports + 1),
VarInsnNode(Opcodes.ALOAD, importParamOffset + 1),
FieldInsnNode(Opcodes.PUTFIELD, ctx.thisRef.asmName,
ctx.importGlobalGetterFieldName(importIndex), MethodHandle::class.ref.asmDesc)
)
}
).let { func ->
// If it's mutable, it has a second setter handle
if ((import.kind as? Node.Import.Kind.Global)?.type?.mutable == false) func to importParamOffset + 1
else func.addInsns(
VarInsnNode(Opcodes.ALOAD, 0),
VarInsnNode(Opcodes.ALOAD, importParamOffset + 2),
FieldInsnNode(Opcodes.PUTFIELD, ctx.thisRef.asmName,
ctx.importGlobalSetterFieldName(importIndex), MethodHandle::class.ref.asmDesc)
) to importParamOffset + 2
}
}.first
fun setConstructorFunctionImports(ctx: ClsContext, func: Func, paramsBeforeImports: Int) =
ctx.importFuncs.indices.fold(func) { func, importIndex ->
@ -261,7 +286,10 @@ open class AstToAsm {
fun setConstructorTableImports(ctx: ClsContext, func: Func, paramsBeforeImports: Int) =
if (ctx.mod.imports.none { it.kind is Node.Import.Kind.Table }) func else {
val importIndex = ctx.importFuncs.size + ctx.importGlobals.size + paramsBeforeImports + 1
val importIndex = ctx.importFuncs.size +
// Mutable global imports have setters and take up two spots
ctx.importGlobals.sumBy { if ((it.kind as? Node.Import.Kind.Global)?.type?.mutable == true) 2 else 1 } +
paramsBeforeImports + 1
func.addInsns(
VarInsnNode(Opcodes.ALOAD, 0),
VarInsnNode(Opcodes.ALOAD, importIndex),
@ -299,11 +327,14 @@ open class AstToAsm {
global.type.contentType.typeRef,
refGlobalKind.type.contentType.typeRef
)
val paramOffset = ctx.importFuncs.size + paramsBeforeImports + 1 +
ctx.importGlobals.take(it.index).sumBy {
// Immutable jumps 1, mutable jumps 2
if ((it.kind as? Node.Import.Kind.Global)?.type?.mutable == false) 1
else 2
}
listOf(
VarInsnNode(
Opcodes.ALOAD,
ctx.importFuncs.size + it.index + paramsBeforeImports + 1
),
VarInsnNode(Opcodes.ALOAD, paramOffset),
MethodInsnNode(
Opcodes.INVOKEVIRTUAL,
MethodHandle::class.ref.asmName,
@ -356,7 +387,10 @@ open class AstToAsm {
// Otherwise, it was imported and we can set the elems on the imported one
// from the parameter
// TODO: I think this is a security concern and bad practice, may revisit
val importIndex = ctx.importFuncs.size + ctx.importGlobals.size + paramsBeforeImports + 1
val importIndex = ctx.importFuncs.size + ctx.importGlobals.sumBy {
// Immutable is 1, mutable is 2
if ((it.kind as? Node.Import.Kind.Global)?.type?.mutable == false) 1 else 2
} + paramsBeforeImports + 1
return func.addInsns(VarInsnNode(Opcodes.ALOAD, importIndex)).
let { func -> addElemsToTable(ctx, func, paramsBeforeImports) }.
// Remove the array that's still there
@ -532,28 +566,58 @@ open class AstToAsm {
is Either.Left -> (global.v.kind as Node.Import.Kind.Global).type
is Either.Right -> global.v.type
}
if (type.mutable) throw CompileErr.MutableGlobalExport(export.index)
// Create a simple getter
val method = MethodNode(Opcodes.ACC_PUBLIC, "get" + export.field.javaIdent.capitalize(),
val getter = MethodNode(Opcodes.ACC_PUBLIC, "get" + export.field.javaIdent.capitalize(),
"()" + type.contentType.typeRef.asmDesc, null, null)
method.addInsns(VarInsnNode(Opcodes.ALOAD, 0))
if (global is Either.Left) method.addInsns(
getter.addInsns(VarInsnNode(Opcodes.ALOAD, 0))
if (global is Either.Left) getter.addInsns(
FieldInsnNode(Opcodes.GETFIELD, ctx.thisRef.asmName,
ctx.importGlobalGetterFieldName(export.index), MethodHandle::class.ref.asmDesc),
MethodInsnNode(Opcodes.INVOKEVIRTUAL, MethodHandle::class.ref.asmName, "invokeExact",
"()" + type.contentType.typeRef.asmDesc, false)
) else method.addInsns(
) else getter.addInsns(
FieldInsnNode(Opcodes.GETFIELD, ctx.thisRef.asmName, ctx.globalName(export.index),
type.contentType.typeRef.asmDesc)
)
method.addInsns(InsnNode(when (type.contentType) {
getter.addInsns(InsnNode(when (type.contentType) {
Node.Type.Value.I32 -> Opcodes.IRETURN
Node.Type.Value.I64 -> Opcodes.LRETURN
Node.Type.Value.F32 -> Opcodes.FRETURN
Node.Type.Value.F64 -> Opcodes.DRETURN
}))
method.visibleAnnotations = listOf(exportAnnotation(export))
ctx.cls.methods.plusAssign(method)
getter.visibleAnnotations = listOf(exportAnnotation(export))
ctx.cls.methods.plusAssign(getter)
// If mutable, create simple setter
if (type.mutable) {
val setter = MethodNode(Opcodes.ACC_PUBLIC, "set" + export.field.javaIdent.capitalize(),
"(${type.contentType.typeRef.asmDesc})V", null, null)
setter.addInsns(VarInsnNode(Opcodes.ALOAD, 0))
if (global is Either.Left) setter.addInsns(
FieldInsnNode(Opcodes.GETFIELD, ctx.thisRef.asmName,
ctx.importGlobalSetterFieldName(export.index), MethodHandle::class.ref.asmDesc),
VarInsnNode(when (type.contentType) {
Node.Type.Value.I32 -> Opcodes.ILOAD
Node.Type.Value.I64 -> Opcodes.LLOAD
Node.Type.Value.F32 -> Opcodes.FLOAD
Node.Type.Value.F64 -> Opcodes.DLOAD
}, 1),
MethodInsnNode(Opcodes.INVOKEVIRTUAL, MethodHandle::class.ref.asmName, "invokeExact",
"(${type.contentType.typeRef.asmDesc})V", false),
InsnNode(Opcodes.RETURN)
) else setter.addInsns(
VarInsnNode(when (type.contentType) {
Node.Type.Value.I32 -> Opcodes.ILOAD
Node.Type.Value.I64 -> Opcodes.LLOAD
Node.Type.Value.F32 -> Opcodes.FLOAD
Node.Type.Value.F64 -> Opcodes.DLOAD
}, 1),
FieldInsnNode(Opcodes.PUTFIELD, ctx.thisRef.asmName, ctx.globalName(export.index),
type.contentType.typeRef.asmDesc),
InsnNode(Opcodes.RETURN)
)
setter.visibleAnnotations = listOf(exportAnnotation(export))
ctx.cls.methods.plusAssign(setter)
}
}
fun addExportMemory(ctx: ClsContext, export: Node.Export) {

View File

@ -4,32 +4,30 @@ import asmble.ast.Node
import org.objectweb.asm.Opcodes
import org.objectweb.asm.Type
import org.objectweb.asm.tree.*
import java.nio.Buffer
import java.nio.ByteBuffer
import java.nio.ByteOrder
import kotlin.reflect.KClass
import kotlin.reflect.KFunction
open class ByteBufferMem(val direct: Boolean = true) : Mem {
override val memType = ByteBuffer::class.ref
open class ByteBufferMem : Mem {
override val memType: TypeRef = MemoryBuffer::class.ref
override fun limitAndCapacity(instance: Any) =
if (instance !is ByteBuffer) error("Unrecognized memory instance: $instance")
override fun limitAndCapacity(instance: Any): Pair<Int, Int> =
if (instance !is MemoryBuffer) error("Unrecognized memory instance: $instance")
else instance.limit() to instance.capacity()
override fun create(func: Func) = func.popExpecting(Int::class.ref).addInsns(
(if (direct) ByteBuffer::allocateDirect else ByteBuffer::allocate).invokeStatic()
(MemoryBuffer::init).invokeStatic()
).push(memType)
override fun init(func: Func, initial: Int) = func.popExpecting(memType).addInsns(
// Set the limit to initial
(initial * Mem.PAGE_SIZE).const,
forceFnType<ByteBuffer.(Int) -> Buffer>(ByteBuffer::limit).invokeVirtual(),
TypeInsnNode(Opcodes.CHECKCAST, ByteBuffer::class.ref.asmName),
forceFnType<MemoryBuffer.(Int) -> MemoryBuffer>(MemoryBuffer::limit).invokeVirtual(),
TypeInsnNode(Opcodes.CHECKCAST, memType.asmName),
// Set it to use little endian
ByteOrder::LITTLE_ENDIAN.getStatic(),
forceFnType<ByteBuffer.(ByteOrder) -> ByteBuffer>(ByteBuffer::order).invokeVirtual()
).push(ByteBuffer::class.ref)
forceFnType<MemoryBuffer.(ByteOrder) -> MemoryBuffer>(MemoryBuffer::order).invokeVirtual()
).push(memType)
override fun data(func: Func, bytes: ByteArray, buildOffset: (Func) -> Func) =
// Sadly there is no absolute bulk put, so we need to fake one. Ref:
@ -42,25 +40,34 @@ open class ByteBufferMem(val direct: Boolean = true) : Mem {
// where we could call put directly, but it too is negligible for now.
// Note, with this approach, the mem will not be left on the stack for future data() calls, which is fine.
func.popExpecting(memType).
addInsns(ByteBuffer::duplicate.invokeVirtual()).
addInsns(MemoryBuffer::duplicate.invokeVirtual()).
let(buildOffset).popExpecting(Int::class.ref).
addInsns(
forceFnType<ByteBuffer.(Int) -> Buffer>(ByteBuffer::position).invokeVirtual(),
TypeInsnNode(Opcodes.CHECKCAST, memType.asmName),
// We're going to do this as an LDC string in ISO-8859 and read it back at runtime
LdcInsnNode(bytes.toString(Charsets.ISO_8859_1)),
LdcInsnNode("ISO-8859-1"),
// Ug, can't do func refs on native types here...
MethodInsnNode(Opcodes.INVOKEVIRTUAL, String::class.ref.asmName,
"getBytes", "(Ljava/lang/String;)[B", false),
0.const,
bytes.size.const,
forceFnType<ByteBuffer.(ByteArray, Int, Int) -> ByteBuffer>(ByteBuffer::put).invokeVirtual(),
forceFnType<MemoryBuffer.(Int) -> MemoryBuffer>(MemoryBuffer::position).invokeVirtual(),
TypeInsnNode(Opcodes.CHECKCAST, memType.asmName)
).addInsns(
// We're going to do this as an LDC string in ISO-8859 and read it back at runtime. However,
// due to JVM limits, we can't have a string > 65536 chars. We chunk into 16300 because when
// converting to UTF8 const it can be up to 4 bytes per char, so this makes sure it doesn't
// overflow.
bytes.chunked(16300).flatMap { bytes ->
sequenceOf(
LdcInsnNode(bytes.toString(Charsets.ISO_8859_1)),
LdcInsnNode("ISO-8859-1"),
// Ug, can't do func refs on native types here...
MethodInsnNode(Opcodes.INVOKEVIRTUAL, String::class.ref.asmName,
"getBytes", "(Ljava/lang/String;)[B", false),
0.const,
bytes.size.const,
forceFnType<MemoryBuffer.(ByteArray, Int, Int) -> MemoryBuffer>(MemoryBuffer::put).invokeVirtual()
)
}.toList()
).addInsns(
InsnNode(Opcodes.POP)
)
override fun currentMemory(ctx: FuncContext, func: Func) = func.popExpecting(memType).addInsns(
forceFnType<ByteBuffer.() -> Int>(ByteBuffer::limit).invokeVirtual(),
forceFnType<MemoryBuffer.() -> Int>(MemoryBuffer::limit).invokeVirtual(),
Mem.PAGE_SIZE.const,
InsnNode(Opcodes.IDIV)
).push(Int::class.ref)
@ -77,10 +84,10 @@ open class ByteBufferMem(val direct: Boolean = true) : Mem {
val okLim = LabelNode()
val node = MethodNode(
Opcodes.ACC_PRIVATE + Opcodes.ACC_STATIC + Opcodes.ACC_SYNTHETIC,
"\$\$growMemory", "(Ljava/nio/ByteBuffer;I)I", null, null
"\$\$growMemory", "(Lasmble/compile/jvm/MemoryBuffer;I)I", null, null
).addInsns(
VarInsnNode(Opcodes.ALOAD, 0), // [mem]
forceFnType<ByteBuffer.() -> Int>(ByteBuffer::limit).invokeVirtual(), // [lim]
forceFnType<MemoryBuffer.() -> Int>(MemoryBuffer::limit).invokeVirtual(), // [lim]
InsnNode(Opcodes.DUP), // [lim, lim]
VarInsnNode(Opcodes.ALOAD, 0), // [lim, lim, mem]
InsnNode(Opcodes.SWAP), // [lim, mem, lim]
@ -93,7 +100,7 @@ open class ByteBufferMem(val direct: Boolean = true) : Mem {
InsnNode(Opcodes.LADD), // [lim, mem, newlimL]
InsnNode(Opcodes.DUP2), // [lim, mem, newlimL, newlimL]
VarInsnNode(Opcodes.ALOAD, 0), // [lim, mem, newlimL, newlimL, mem]
ByteBuffer::capacity.invokeVirtual(), // [lim, mem, newlimL, newlimL, cap]
MemoryBuffer::capacity.invokeVirtual(), // [lim, mem, newlimL, newlimL, cap]
InsnNode(Opcodes.I2L), // [lim, mem, newlimL, newlimL, capL]
InsnNode(Opcodes.LCMP), // [lim, mem, newlimL, cmpres]
JumpInsnNode(Opcodes.IFLE, okLim), // [lim, mem, newlimL]
@ -102,7 +109,7 @@ open class ByteBufferMem(val direct: Boolean = true) : Mem {
InsnNode(Opcodes.IRETURN),
okLim, // [lim, mem, newlimL]
InsnNode(Opcodes.L2I), // [lim, mem, newlim]
forceFnType<ByteBuffer.(Int) -> Buffer>(ByteBuffer::limit).invokeVirtual(), // [lim, mem]
forceFnType<MemoryBuffer.(Int) -> MemoryBuffer>(MemoryBuffer::limit).invokeVirtual(), // [lim, mem]
InsnNode(Opcodes.POP), // [lim]
Mem.PAGE_SIZE.const, // [lim, pagesize]
InsnNode(Opcodes.IDIV), // [limpages]
@ -116,7 +123,7 @@ open class ByteBufferMem(val direct: Boolean = true) : Mem {
// Ug, some tests expect this to be a runtime failure so we feature flagged it
if (ctx.cls.eagerFailLargeMemOffset)
require(insn.offset <= Int.MAX_VALUE, { "Offsets > ${Int.MAX_VALUE} unsupported" }).let { this }
fun Func.load(fn: ByteBuffer.(Int) -> Any, retClass: KClass<*>) =
fun Func.load(fn: MemoryBuffer.(Int) -> Any, retClass: KClass<*>) =
this.popExpecting(Int::class.ref).let { func ->
// No offset means we'll access it directly
(if (insn.offset == 0L) func else {
@ -132,9 +139,9 @@ open class ByteBufferMem(val direct: Boolean = true) : Mem {
}
}).popExpecting(memType).addInsns((fn as KFunction<*>).invokeVirtual())
}.push(retClass.ref)
fun Func.loadI32(fn: ByteBuffer.(Int) -> Any) =
fun Func.loadI32(fn: MemoryBuffer.(Int) -> Any) =
this.load(fn, Int::class)
fun Func.loadI64(fn: ByteBuffer.(Int) -> Any) =
fun Func.loadI64(fn: MemoryBuffer.(Int) -> Any) =
this.load(fn, Long::class)
/* Ug: https://youtrack.jetbrains.com/issue/KT-17064
fun Func.toUnsigned(fn: KFunction<*>) =
@ -154,33 +161,33 @@ open class ByteBufferMem(val direct: Boolean = true) : Mem {
// Had to move this in here instead of as first expr because of https://youtrack.jetbrains.com/issue/KT-8689
return when (insn) {
is Node.Instr.I32Load ->
func.loadI32(ByteBuffer::getInt)
func.loadI32(MemoryBuffer::getInt)
is Node.Instr.I64Load ->
func.loadI64(ByteBuffer::getLong)
func.loadI64(MemoryBuffer::getLong)
is Node.Instr.F32Load ->
func.load(ByteBuffer::getFloat, Float::class)
func.load(MemoryBuffer::getFloat, Float::class)
is Node.Instr.F64Load ->
func.load(ByteBuffer::getDouble, Double::class)
func.load(MemoryBuffer::getDouble, Double::class)
is Node.Instr.I32Load8S ->
func.loadI32(ByteBuffer::get)
func.loadI32(MemoryBuffer::get)
is Node.Instr.I32Load8U ->
func.loadI32(ByteBuffer::get).toUnsigned32(java.lang.Byte::class, "toUnsignedInt", Byte::class)
func.loadI32(MemoryBuffer::get).toUnsigned32(java.lang.Byte::class, "toUnsignedInt", Byte::class)
is Node.Instr.I32Load16S ->
func.loadI32(ByteBuffer::getShort)
func.loadI32(MemoryBuffer::getShort)
is Node.Instr.I32Load16U ->
func.loadI32(ByteBuffer::getShort).toUnsigned32(java.lang.Short::class, "toUnsignedInt", Short::class)
func.loadI32(MemoryBuffer::getShort).toUnsigned32(java.lang.Short::class, "toUnsignedInt", Short::class)
is Node.Instr.I64Load8S ->
func.loadI32(ByteBuffer::get).i32ToI64()
func.loadI32(MemoryBuffer::get).i32ToI64()
is Node.Instr.I64Load8U ->
func.loadI32(ByteBuffer::get).toUnsigned64(java.lang.Byte::class, "toUnsignedLong", Byte::class)
func.loadI32(MemoryBuffer::get).toUnsigned64(java.lang.Byte::class, "toUnsignedLong", Byte::class)
is Node.Instr.I64Load16S ->
func.loadI32(ByteBuffer::getShort).i32ToI64()
func.loadI32(MemoryBuffer::getShort).i32ToI64()
is Node.Instr.I64Load16U ->
func.loadI32(ByteBuffer::getShort).toUnsigned64(java.lang.Short::class, "toUnsignedLong", Short::class)
func.loadI32(MemoryBuffer::getShort).toUnsigned64(java.lang.Short::class, "toUnsignedLong", Short::class)
is Node.Instr.I64Load32S ->
func.loadI32(ByteBuffer::getInt).i32ToI64()
func.loadI32(MemoryBuffer::getInt).i32ToI64()
is Node.Instr.I64Load32U ->
func.loadI32(ByteBuffer::getInt).toUnsigned64(java.lang.Integer::class, "toUnsignedLong", Int::class)
func.loadI32(MemoryBuffer::getInt).toUnsigned64(java.lang.Integer::class, "toUnsignedLong", Int::class)
else -> throw IllegalArgumentException("Unknown load op $insn")
}
}
@ -215,12 +222,12 @@ open class ByteBufferMem(val direct: Boolean = true) : Mem {
popExpecting(Int::class.ref).
popExpecting(memType).
addInsns(fn).
push(ByteBuffer::class.ref)
push(memType)
}
// Ug, I hate these as strings but can't introspect Kotlin overloads
fun bufStoreFunc(name: String, valType: KClass<*>) =
MethodInsnNode(Opcodes.INVOKEVIRTUAL, ByteBuffer::class.ref.asmName, name,
ByteBuffer::class.ref.asMethodRetDesc(Int::class.ref, valType.ref), false)
MethodInsnNode(Opcodes.INVOKEVIRTUAL, memType.asmName, name,
memType.asMethodRetDesc(Int::class.ref, valType.ref), false)
fun Func.changeI64ToI32() =
this.popExpecting(Long::class.ref).push(Int::class.ref)
when (insn) {

View File

@ -39,6 +39,24 @@ data class ClsContext(
val hasTable: Boolean by lazy {
mod.tables.isNotEmpty() || mod.imports.any { it.kind is Node.Import.Kind.Table }
}
val dedupedFuncNames: Map<Int, String>? by lazy {
// Consider all exports as seen
val seen = mod.exports.flatMap { export ->
when {
export.kind == Node.ExternalKind.FUNCTION -> listOf(export.field.javaIdent)
// Just to make it easy, consider all globals as having setters
export.kind == Node.ExternalKind.GLOBAL ->
export.field.javaIdent.capitalize().let { listOf("get$it", "set$it") }
else -> listOf("get" + export.field.javaIdent.capitalize())
}
}.toMutableSet()
mod.names?.funcNames?.toList()?.sortedBy { it.first }?.map { (index, origName) ->
var name = origName.javaIdent
var nameIndex = 0
while (!seen.add(name)) name = origName.javaIdent + (nameIndex++)
index to name
}?.toMap()
}
fun assertHasMemory() { if (!hasMemory) throw CompileErr.UnknownMemory(0) }
@ -71,7 +89,7 @@ data class ClsContext(
fun importGlobalGetterFieldName(index: Int) = "import\$get" + globalName(index)
fun importGlobalSetterFieldName(index: Int) = "import\$set" + globalName(index)
fun globalName(index: Int) = "\$global$index"
fun funcName(index: Int) = "\$func$index"
fun funcName(index: Int) = dedupedFuncNames?.get(index) ?: "\$func$index"
private fun syntheticFunc(
nameSuffix: String,

View File

@ -102,18 +102,6 @@ sealed class CompileErr(message: String, cause: Throwable? = null) : RuntimeExce
override val asmErrString get() = "global is immutable"
}
class MutableGlobalImport(
val index: Int
) : CompileErr("Attempted to import mutable global at index $index") {
override val asmErrString get() = "mutable globals cannot be imported"
}
class MutableGlobalExport(
val index: Int
) : CompileErr("Attempted to export global $index which is mutable") {
override val asmErrString get() = "mutable globals cannot be exported"
}
class GlobalInitNotConstant(
val index: Int
) : CompileErr("Expected init for global $index to be single constant value") {

View File

@ -4,6 +4,24 @@ import asmble.ast.Node
import org.objectweb.asm.Opcodes
import org.objectweb.asm.tree.*
/**
* Jvm representation of a function.
*
* @param name Name of the fn.
* @param params List of parameter types of the fn.
* @param ret Type of the fn's return value.
* @param access A mask of access flags denoting access permissions to and
* properties of this method
* [https://docs.oracle.com/javase/specs/jvms/se8/html/jvms-4.html#jvms-4.1-200-E.1].
* @param insns List of nodes, each representing a bytecode instruction.
* @param stack A stack of operand types: a mirror of the JVM operand stack that
* holds the types of the operands instead of the operands themselves.
* @param blockStack List of currently open code blocks.
* @param ifStack Contains indices of [org.objectweb.asm.tree.JumpInsnNode]s that
* initially have a null label.
* @param lastStackIsMemLeftover If the memory reference is left on the stack and
* will be needed again, it is marked as leftover and reused.
*/
data class Func(
val name: String,
val params: List<TypeRef> = emptyList(),
@ -12,7 +30,6 @@ data class Func(
val insns: List<AbstractInsnNode> = emptyList(),
val stack: List<TypeRef> = emptyList(),
val blockStack: List<Block> = emptyList(),
// Contains index of JumpInsnNode that has a null label initially
val ifStack: List<Int> = emptyList(),
val lastStackIsMemLeftover: Boolean = false
) {
@ -110,10 +127,11 @@ data class Func(
}
}
fun pushBlock(insn: Node.Instr, labelType: Node.Type.Value?, endType: Node.Type.Value?) =
/** Creates a new block for the given instruction and pushes it onto the blockStack. */
fun pushBlock(insn: Node.Instr, labelType: Node.Type.Value?, endType: Node.Type.Value?): Func =
pushBlock(insn, listOfNotNull(labelType?.typeRef), listOfNotNull(endType?.typeRef))
fun pushBlock(insn: Node.Instr, labelTypes: List<TypeRef>, endTypes: List<TypeRef>) =
fun pushBlock(insn: Node.Instr, labelTypes: List<TypeRef>, endTypes: List<TypeRef>): Func =
copy(blockStack = blockStack + Block(insn, insns.size, stack, labelTypes, endTypes))
fun popBlock() = copy(blockStack = blockStack.dropLast(1)) to blockStack.last()
@ -127,6 +145,22 @@ data class Func(
fun popIf() = copy(ifStack = ifStack.dropLast(1)) to peekIf()
/**
* Representation of code block.
*
* Blocks are composed of matched pairs of `block ... end` instructions, loops
* with matched pairs of `loop ... end` instructions, and ifs with either
* `if ... end` or `if ... else ... end` sequences. For each of these constructs
* the instructions in the ellipsis are said to be enclosed in the construct.
*
* @param insn Start instruction of this block; might be a 'Block', 'Loop'
* or 'If'
* @param startIndex Index of this block's start instruction in the list of all
* instructions
* @param origStack The stack of operand types as it was when the block started
* @param labelTypes The label types of this block
* @param endTypes The types the block leaves on the stack at its end
*/
class Block(
val insn: Node.Instr,
val startIndex: Int,

View File

@ -14,23 +14,31 @@ import java.lang.invoke.MethodHandle
// TODO: modularize
open class FuncBuilder {
fun fromFunc(ctx: ClsContext, f: Node.Func, index: Int): Func {
/**
* Converts wasm AST [asmble.ast.Node.Func] to Jvm bytecode representation [asmble.compile.jvm.Func].
*
* @param ctx The class-level compilation context.
* @param fn The wasm fn AST.
* @param index The fn index, used for generating the fn name.
*/
fun fromFunc(ctx: ClsContext, fn: Node.Func, index: Int): Func {
ctx.debug { "Building function ${ctx.funcName(index)}" }
ctx.trace { "Function ast:\n${SExprToStr.fromSExpr(AstToSExpr.fromFunc(f))}" }
ctx.trace { "Function ast:\n${SExprToStr.fromSExpr(AstToSExpr.fromFunc(fn))}" }
var func = Func(
access = Opcodes.ACC_PRIVATE,
name = ctx.funcName(index),
params = f.type.params.map(Node.Type.Value::typeRef),
ret = f.type.ret?.let(Node.Type.Value::typeRef) ?: Void::class.ref
params = fn.type.params.map(Node.Type.Value::typeRef),
ret = fn.type.ret?.let(Node.Type.Value::typeRef) ?: Void::class.ref
)
// Rework the instructions
val reworkedInsns = ctx.reworker.rework(ctx, f)
val reworkedInsns = ctx.reworker.rework(ctx, fn)
// Start the implicit block
func = func.pushBlock(Node.Instr.Block(f.type.ret), f.type.ret, f.type.ret)
func = func.pushBlock(Node.Instr.Block(fn.type.ret), fn.type.ret, fn.type.ret)
// Create the context
val funcCtx = FuncContext(
cls = ctx,
node = f,
node = fn,
insns = reworkedInsns,
memIsLocalVar =
ctx.reworker.nonAdjacentMemAccesses(reworkedInsns) >= ctx.nonAdjacentMemAccessesRequiringLocalVar
@ -46,9 +54,9 @@ open class FuncBuilder {
// Add all instructions
ctx.debug { "Applying insns for function ${ctx.funcName(index)}" }
// All functions have an implicit block
func = funcCtx.insns.foldIndexed(func) { index, func, insn ->
func = funcCtx.insns.foldIndexed(func) { idx, f, insn ->
ctx.debug { "Applying insn $insn" }
val ret = applyInsn(funcCtx, func, insn, index)
val ret = applyInsn(funcCtx, f, insn, idx)
ctx.trace { "Resulting stack: ${ret.stack}"}
ret
}
@ -56,11 +64,11 @@ open class FuncBuilder {
// End the implicit block
val implicitBlock = func.currentBlock
func = applyEnd(funcCtx, func)
f.type.ret?.typeRef?.also { func = func.popExpecting(it, implicitBlock) }
fn.type.ret?.typeRef?.also { func = func.popExpecting(it, implicitBlock) }
// If the last instruction does not terminate, add the expected return
if (func.insns.isEmpty() || !func.insns.last().isTerminating) {
func = func.addInsns(InsnNode(when (f.type.ret) {
func = func.addInsns(InsnNode(when (fn.type.ret) {
null -> Opcodes.RETURN
Node.Type.Value.I32 -> Opcodes.IRETURN
Node.Type.Value.I64 -> Opcodes.LRETURN
@ -72,8 +80,10 @@ open class FuncBuilder {
}
fun applyInsn(ctx: FuncContext, fn: Func, i: Insn, index: Int) = when (i) {
is Insn.Node ->
applyNodeInsn(ctx, fn, i.insn, index)
is Insn.ImportFuncRefNeededOnStack ->
// Func refs are method handle fields
fn.addInsns(
@ -81,6 +91,7 @@ open class FuncBuilder {
FieldInsnNode(Opcodes.GETFIELD, ctx.cls.thisRef.asmName,
ctx.cls.funcName(i.index), MethodHandle::class.ref.asmDesc)
).push(MethodHandle::class.ref)
is Insn.ImportGlobalSetRefNeededOnStack ->
// Import setters are method handle fields
fn.addInsns(
@ -88,13 +99,17 @@ open class FuncBuilder {
FieldInsnNode(Opcodes.GETFIELD, ctx.cls.thisRef.asmName,
ctx.cls.importGlobalSetterFieldName(i.index), MethodHandle::class.ref.asmDesc)
).push(MethodHandle::class.ref)
is Insn.ThisNeededOnStack ->
// load a reference onto the stack from a local variable
fn.addInsns(VarInsnNode(Opcodes.ALOAD, 0)).push(ctx.cls.thisRef)
is Insn.MemNeededOnStack ->
putMemoryOnStack(ctx, fn)
}
fun applyNodeInsn(ctx: FuncContext, fn: Func, i: Node.Instr, index: Int) = when (i) {
is Node.Instr.Unreachable ->
fn.addInsns(UnsupportedOperationException::class.athrow("Unreachable")).markUnreachable()
is Node.Instr.Nop ->
@ -127,18 +142,16 @@ open class FuncBuilder {
fn.pop().let { (fn, popped) ->
fn.addInsns(InsnNode(if (popped.stackSize == 2) Opcodes.POP2 else Opcodes.POP))
}
is Node.Instr.Select ->
applySelectInsn(ctx, fn)
is Node.Instr.GetLocal ->
applyGetLocal(ctx, fn, i.index)
is Node.Instr.SetLocal ->
applySetLocal(ctx, fn, i.index)
is Node.Instr.TeeLocal ->
applyTeeLocal(ctx, fn, i.index)
is Node.Instr.GetGlobal ->
applyGetGlobal(ctx, fn, i.index)
is Node.Instr.SetGlobal ->
applySetGlobal(ctx, fn, i.index)
is Node.Instr.Select -> applySelectInsn(ctx, fn)
// Variable access
is Node.Instr.GetLocal -> applyGetLocal(ctx, fn, i.index)
is Node.Instr.SetLocal -> applySetLocal(ctx, fn, i.index)
is Node.Instr.TeeLocal -> applyTeeLocal(ctx, fn, i.index)
is Node.Instr.GetGlobal -> applyGetGlobal(ctx, fn, i.index)
is Node.Instr.SetGlobal -> applySetGlobal(ctx, fn, i.index)
// Memory operators
is Node.Instr.I32Load, is Node.Instr.I64Load, is Node.Instr.F32Load, is Node.Instr.F64Load,
is Node.Instr.I32Load8S, is Node.Instr.I32Load8U, is Node.Instr.I32Load16U, is Node.Instr.I32Load16S,
is Node.Instr.I64Load8S, is Node.Instr.I64Load8U, is Node.Instr.I64Load16U, is Node.Instr.I64Load16S,
@ -148,10 +161,10 @@ open class FuncBuilder {
is Node.Instr.I32Store8, is Node.Instr.I32Store16, is Node.Instr.I64Store8, is Node.Instr.I64Store16,
is Node.Instr.I64Store32 ->
applyStoreOp(ctx, fn, i as Node.Instr.Args.AlignOffset, index)
is Node.Instr.CurrentMemory ->
applyCurrentMemory(ctx, fn)
is Node.Instr.GrowMemory ->
applyGrowMemory(ctx, fn)
is Node.Instr.MemorySize ->
applyMemorySize(ctx, fn)
is Node.Instr.MemoryGrow ->
applyMemoryGrow(ctx, fn)
is Node.Instr.I32Const ->
fn.addInsns(i.value.const).push(Int::class.ref)
is Node.Instr.I64Const ->
@ -461,18 +474,18 @@ open class FuncBuilder {
fun popUntilStackSize(
ctx: FuncContext,
fn: Func,
func: Func,
block: Func.Block,
untilStackSize: Int,
keepLast: Boolean
): Func {
ctx.debug { "For block ${block.insn}, popping until stack size $untilStackSize, keeping last? $keepLast" }
// Just get the latest, don't actually pop...
val type = if (keepLast) fn.pop().second else null
return (0 until Math.max(0, fn.stack.size - untilStackSize)).fold(fn) { fn, _ ->
val type = if (keepLast) func.pop().second else null
return (0 until Math.max(0, func.stack.size - untilStackSize)).fold(func) { fn, _ ->
// Essentially swap and pop if they want to keep the latest
(if (type != null && fn.stack.size > 1) fn.stackSwap(block) else fn).let { fn ->
fn.pop(block).let { (fn, poppedType) ->
(if (type != null && fn.stack.size > 1) fn.stackSwap(block) else fn).let { f ->
f.pop(block).let { (fn, poppedType) ->
fn.addInsns(InsnNode(if (poppedType.stackSize == 2) Opcodes.POP2 else Opcodes.POP))
}
}
@ -1062,24 +1075,26 @@ open class FuncBuilder {
).push(Int::class.ref)
}
fun applyGrowMemory(ctx: FuncContext, fn: Func) =
fun applyMemoryGrow(ctx: FuncContext, fn: Func) =
// Grow mem is a special case where the memory ref is already pre-injected on
// the stack before this call. Result is an int.
ctx.cls.assertHasMemory().let {
ctx.cls.mem.growMemory(ctx, fn)
}
fun applyCurrentMemory(ctx: FuncContext, fn: Func) =
fun applyMemorySize(ctx: FuncContext, fn: Func) =
// Curr mem is not specially injected, so we have to put the memory on the
// stack since we need it
ctx.cls.assertHasMemory().let {
putMemoryOnStack(ctx, fn).let { fn -> ctx.cls.mem.currentMemory(ctx, fn) }
}
/**
* Store is a special case where the memory ref is already pre-injected on
* the stack before this call. But it can have a memory leftover on the stack
* so we pop it if we need to
*/
fun applyStoreOp(ctx: FuncContext, fn: Func, insn: Node.Instr.Args.AlignOffset, insnIndex: Int) =
// Store is a special case where the memory ref is already pre-injected on
// the stack before this call. But it can have a memory leftover on the stack
// so we pop it if we need to
ctx.cls.assertHasMemory().let {
ctx.cls.mem.storeOp(ctx, fn, insn).let { fn ->
// As a special case, if this leaves the mem on the stack
@ -1268,11 +1283,11 @@ open class FuncBuilder {
}
}
fun applyReturnInsn(ctx: FuncContext, fn: Func): Func {
// If the current stakc is unreachable, we consider that our block since it
fun applyReturnInsn(ctx: FuncContext, func: Func): Func {
// If the current stack is unreachable, we consider that our block since it
// will pop properly.
val block = if (fn.currentBlock.unreachable) fn.currentBlock else fn.blockStack.first()
popForBlockEscape(ctx, fn, block).let { fn ->
val block = if (func.currentBlock.unreachable) func.currentBlock else func.blockStack.first()
popForBlockEscape(ctx, func, block).let { fn ->
return when (ctx.node.type.ret) {
null ->
fn.addInsns(InsnNode(Opcodes.RETURN))
@ -1284,9 +1299,9 @@ open class FuncBuilder {
fn.popExpecting(Float::class.ref, block).addInsns(InsnNode(Opcodes.FRETURN))
Node.Type.Value.F64 ->
fn.popExpecting(Double::class.ref, block).addInsns(InsnNode(Opcodes.DRETURN))
}.let { fn ->
if (fn.stack.isNotEmpty()) throw CompileErr.UnusedStackOnReturn(fn.stack)
fn.markUnreachable()
}.let { it ->
if (it.stack.isNotEmpty()) throw CompileErr.UnusedStackOnReturn(it.stack)
it.markUnreachable()
}
}
}

View File

@ -3,6 +3,15 @@ package asmble.compile.jvm
import asmble.ast.Node
import asmble.util.Logger
/**
* JVM execution context of a function.
*
* @param cls Class execution context
* @param node AST of this function
* @param insns A list of instructions
* @param memIsLocalVar If true, the function uses only local variables and doesn't
* load from or store to memory.
*/
data class FuncContext(
val cls: ClsContext,
val node: Node.Func,

View File

@ -2,6 +2,9 @@ package asmble.compile.jvm
import asmble.ast.Node
/**
* Performs some special manipulations on the instructions.
*/
open class InsnReworker {
fun rework(ctx: ClsContext, func: Node.Func): List<Insn> {
@ -29,6 +32,7 @@ open class InsnReworker {
// Note, while walking backwards up the insns to find set/tee, we do skip entire
// blocks/loops/if+else combined with "end"
var neededEagerLocalIndices = emptySet<Int>()
fun addEagerSetIfNeeded(getInsnIndex: Int, localIndex: Int) {
// Within the param range? nothing needed
if (localIndex < func.type.params.size) return
@ -71,6 +75,7 @@ open class InsnReworker {
(insn is Insn.Node && insn.insn !is Node.Instr.SetLocal && insn.insn !is Node.Instr.TeeLocal)
if (needsEagerInit) neededEagerLocalIndices += localIndex
}
insns.forEachIndexed { index, insn ->
if (insn is Insn.Node && insn.insn is Node.Instr.GetLocal) addEagerSetIfNeeded(index, insn.insn.index)
}
@ -86,11 +91,18 @@ open class InsnReworker {
} + insns
}
/**
* Injects into the instruction list the instructions needed to push local variables
* onto the stack and returns the resulting list of instructions.
*
* @param ctx The execution context
* @param insns The original instructions
*/
fun injectNeededStackVars(ctx: ClsContext, insns: List<Node.Instr>): List<Insn> {
ctx.trace { "Calculating places to inject needed stack variables" }
// How we do this:
// We run over each insn, and keep a running list of stack
// manips. If there is an insn that needs something so far back,
// manips. If there is an ins'n that needs something so far back,
// we calc where it needs to be added and keep a running list of
// insn inserts. Then at the end we settle up.
//
@ -109,6 +121,14 @@ open class InsnReworker {
// guarantee the value will be in the right order if there are
// multiple for the same index
var insnsToInject = emptyMap<Int, List<Insn>>()
/**
* Injects the given instruction before the point that produced the last [count] stack values.
*
* @param insn The instruction to inject
* @param count The number of stack values to step back over when
* finding the injection index.
*/
fun injectBeforeLastStackCount(insn: Insn, count: Int) {
ctx.trace { "Injecting $insn back $count stack values" }
fun inject(index: Int) {
@ -130,27 +150,21 @@ open class InsnReworker {
// if we are at 0, add the result of said block if necessary to the count.
if (insideOfBlocks > 0) {
// If it's not a block, just ignore it
val blockStackDiff = insns[insnIndex].let {
when (it) {
is Node.Instr.Block -> if (it.type == null) 0 else 1
is Node.Instr.Loop -> 0
is Node.Instr.If -> if (it.type == null) -1 else 0
else -> null
}
}
if (blockStackDiff != null) {
(insns[insnIndex] as? Node.Instr.Args.Type)?.let {
insideOfBlocks--
ctx.trace { "Found block begin, number of blocks we're still inside: $insideOfBlocks" }
// We're back on our block, change the count
if (insideOfBlocks == 0) countSoFar += blockStackDiff
// We're back on our block, change the count if it had a result
if (insideOfBlocks == 0 && it.type != null) countSoFar++
}
if (insideOfBlocks > 0) continue
}
countSoFar += amountChanged
if (!foundUnconditionalJump) foundUnconditionalJump = insns[insnIndex].let { insn ->
insn is Node.Instr.Br || insn is Node.Instr.BrTable ||
insn is Node.Instr.Unreachable || insn is Node.Instr.Return
if (!foundUnconditionalJump) {
foundUnconditionalJump = insns[insnIndex].let { insn ->
insn is Node.Instr.Br || insn is Node.Instr.BrTable ||
insn is Node.Instr.Unreachable || insn is Node.Instr.Return
}
}
if (countSoFar == count) {
ctx.trace { "Found injection point as before insn #$insnIndex" }
@ -161,6 +175,7 @@ open class InsnReworker {
if (!foundUnconditionalJump) throw CompileErr.StackInjectionMismatch(count, insn)
}
var traceStackSize = 0 // Used only for trace
// Go over each insn, determining where to inject
insns.forEachIndexed { index, insn ->
// Handle special injection cases
@ -194,29 +209,41 @@ open class InsnReworker {
is Node.Instr.I64Store32 ->
injectBeforeLastStackCount(Insn.MemNeededOnStack, 2)
// Grow memory requires "mem" before the single param
is Node.Instr.GrowMemory ->
is Node.Instr.MemoryGrow ->
injectBeforeLastStackCount(Insn.MemNeededOnStack, 1)
else -> { }
}
// Log some trace output
ctx.trace {
insnStackDiff(ctx, insn).let {
traceStackSize += it
"Stack diff is $it for insn #$index $insn, stack size now: $traceStackSize"
}
}
// Add the current diff
ctx.trace { "Stack diff is ${insnStackDiff(ctx, insn)} for insn #$index $insn" }
stackManips += insnStackDiff(ctx, insn) to index
}
// Build resulting list
return insns.foldIndexed(emptyList<Insn>()) { index, ret, insn ->
return insns.foldIndexed(emptyList()) { index, ret, insn ->
val injections = insnsToInject[index] ?: emptyList()
ret + injections + Insn.Node(insn)
}
}
fun insnStackDiff(ctx: ClsContext, insn: Node.Instr) = when (insn) {
/**
* Calculates the stack difference produced by the given instruction, i.e. the change
* in stack size from before the instruction to after it.
* `N = PUSH_OPS - POP_OPS`; a negative value means more pops than pushes.
* If it is `0`, the stack size is unchanged.
*/
fun insnStackDiff(ctx: ClsContext, insn: Node.Instr): Int = when (insn) {
is Node.Instr.Unreachable, is Node.Instr.Nop, is Node.Instr.Block,
is Node.Instr.Loop, is Node.Instr.If, is Node.Instr.Else,
is Node.Instr.End, is Node.Instr.Br, is Node.Instr.BrIf,
is Node.Instr.Loop, is Node.Instr.Else, is Node.Instr.End, is Node.Instr.Br,
is Node.Instr.Return -> NOP
is Node.Instr.BrTable -> POP_PARAM
is Node.Instr.If, is Node.Instr.BrIf, is Node.Instr.BrTable -> POP_PARAM
is Node.Instr.Call -> ctx.funcTypeAtIndex(insn.index).let {
// All calls pop params and any return is a push
(POP_PARAM * it.params.size) + (if (it.ret == null) NOP else PUSH_RESULT)
@ -238,9 +265,9 @@ open class InsnReworker {
is Node.Instr.I64Load32S, is Node.Instr.I64Load32U -> POP_PARAM + PUSH_RESULT
is Node.Instr.I32Store, is Node.Instr.I64Store, is Node.Instr.F32Store, is Node.Instr.F64Store,
is Node.Instr.I32Store8, is Node.Instr.I32Store16, is Node.Instr.I64Store8, is Node.Instr.I64Store16,
is Node.Instr.I64Store32 -> POP_PARAM
is Node.Instr.CurrentMemory -> PUSH_RESULT
is Node.Instr.GrowMemory -> POP_PARAM + PUSH_RESULT
is Node.Instr.I64Store32 -> POP_PARAM + POP_PARAM
is Node.Instr.MemorySize -> PUSH_RESULT
is Node.Instr.MemoryGrow -> POP_PARAM + PUSH_RESULT
is Node.Instr.I32Const, is Node.Instr.I64Const,
is Node.Instr.F32Const, is Node.Instr.F64Const -> PUSH_RESULT
is Node.Instr.I32Add, is Node.Instr.I32Sub, is Node.Instr.I32Mul, is Node.Instr.I32DivS,
@ -284,16 +311,17 @@ open class InsnReworker {
is Node.Instr.F64ReinterpretI64 -> POP_PARAM + PUSH_RESULT
}
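As a small worked illustration of the diff table above, assuming the reworker's conventional constants (NOP = 0, POP_PARAM = -1, PUSH_RESULT = +1 — an assumption, since their values are not shown in this diff):

// Assumed constants matching the semantics of insnStackDiff above (illustrative only)
const val NOP = 0
const val POP_PARAM = -1
const val PUSH_RESULT = 1

fun main() {
    // i64.store32 pops an address and a value and pushes nothing
    check(POP_PARAM + POP_PARAM == -2)
    // memory.size pushes the current page count
    check(PUSH_RESULT == 1)
    // memory.grow pops the page delta and pushes the old page count, so the stack size is unchanged
    check(POP_PARAM + PUSH_RESULT == 0)
}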
fun nonAdjacentMemAccesses(insns: List<Insn>) = insns.fold(0 to false) { (count, lastCouldHaveMem), insn ->
/** Returns the number of non-adjacent memory accesses. */
fun nonAdjacentMemAccesses(insns: List<Insn>): Int = insns.fold(0 to false) { (count, lastCouldHaveMem), insn ->
val inc =
if (lastCouldHaveMem) 0
else if (insn == Insn.MemNeededOnStack) 1
else if (insn is Insn.Node && insn.insn is Node.Instr.CurrentMemory) 1
else if (insn is Insn.Node && insn.insn is Node.Instr.MemorySize) 1
else 0
val couldSetMemNext = if (insn !is Insn.Node) false else when (insn.insn) {
is Node.Instr.I32Store, is Node.Instr.I64Store, is Node.Instr.F32Store, is Node.Instr.F64Store,
is Node.Instr.I32Store8, is Node.Instr.I32Store16, is Node.Instr.I64Store8, is Node.Instr.I64Store16,
is Node.Instr.I64Store32, is Node.Instr.GrowMemory -> true
is Node.Instr.I64Store32, is Node.Instr.MemoryGrow -> true
else -> false
}
(count + inc) to couldSetMemNext

View File

@ -201,7 +201,7 @@ open class Linker {
"instance" + mod.name.javaIdent.capitalize(), mod.ref.asmDesc),
InsnNode(Opcodes.ARETURN)
)
ctx.cls.methods.plusAssign(func)
ctx.cls.methods.plusAssign(func.toMethodNode())
}
class ModuleClass(val cls: Class<*>, overrideName: String? = null) {

View File

@ -0,0 +1,122 @@
package asmble.compile.jvm
import java.nio.ByteBuffer
import java.nio.ByteOrder
/**
* The default implementation of MemoryBuffer, backed by java.nio.ByteBuffer.
*/
open class MemoryByteBuffer(val bb: ByteBuffer) : MemoryBuffer() {
override fun put(arr: ByteArray): MemoryBuffer {
bb.put(arr)
return this
}
override fun clear(): MemoryBuffer {
bb.clear()
return this
}
override fun get(arr: ByteArray): MemoryBuffer {
bb.get(arr)
return this
}
override fun putLong(index: Int, n: Long): MemoryBuffer {
bb.putLong(index, n)
return this
}
override fun putDouble(index: Int, n: Double): MemoryBuffer {
bb.putDouble(index, n)
return this
}
override fun putShort(index: Int, n: Short): MemoryBuffer {
bb.putShort(index, n)
return this
}
override fun putFloat(index: Int, n: Float): MemoryBuffer {
bb.putFloat(index, n)
return this
}
override fun put(index: Int, b: Byte): MemoryBuffer {
bb.put(index, b)
return this
}
override fun putInt(index: Int, n: Int): MemoryBuffer {
bb.putInt(index, n)
return this
}
override fun capacity(): Int {
return bb.capacity()
}
override fun limit(): Int {
return bb.limit()
}
override fun limit(newLimit: Int): MemoryBuffer {
bb.limit(newLimit)
return this
}
override fun position(newPosition: Int): MemoryBuffer {
bb.position(newPosition)
return this
}
override fun order(order: ByteOrder): MemoryBuffer {
bb.order(order)
return this
}
override fun duplicate(): MemoryBuffer {
return MemoryByteBuffer(bb.duplicate())
}
override fun put(arr: ByteArray, offset: Int, length: Int): MemoryBuffer {
bb.put(arr, offset, length)
return this
}
override fun getInt(index: Int): Int {
return bb.getInt(index)
}
override fun get(index: Int): Byte {
return bb.get(index)
}
override fun getLong(index: Int): Long {
return bb.getLong(index)
}
override fun getShort(index: Int): Short {
return bb.getShort(index)
}
override fun getFloat(index: Int): Float {
return bb.getFloat(index)
}
override fun getDouble(index: Int): Double {
return bb.getDouble(index)
}
override fun equals(other: Any?): Boolean {
if (this === other)
return true
if (other !is MemoryByteBuffer)
return false
return bb == other.bb
}
override fun hashCode(): Int {
return bb.hashCode()
}
}
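As a usage sketch (illustrative only, not taken from this change), the wrapper can be constructed around any ByteBuffer; the chained mutators mirror the ByteBuffer API while returning the MemoryBuffer abstraction:

import java.nio.ByteBuffer
import java.nio.ByteOrder

fun memoryBufferDemo(): Int {
    // Wrap a two-page little-endian direct buffer (65,536 bytes per Wasm page)
    val mem: MemoryBuffer = MemoryByteBuffer(
        ByteBuffer.allocateDirect(2 * 65536).order(ByteOrder.LITTLE_ENDIAN)
    )
    // All mutators return the buffer itself, so writes can be chained
    mem.putInt(0, 42).putLong(4, 123L)
    return mem.getInt(0) // 42
}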

View File

@ -2,14 +2,25 @@ package asmble.compile.jvm
import org.objectweb.asm.Type
/**
* A Java field or method type. This class can be used to make it easier to
* manipulate type and method descriptors.
*
* @param asm The wrapped [org.objectweb.asm.Type] from the ASM library
*/
data class TypeRef(val asm: Type) {
/** The internal name of the class corresponding to this object or array type. */
val asmName: String get() = asm.internalName
/** The descriptor corresponding to this Java type. */
val asmDesc: String get() = asm.descriptor
fun asMethodRetDesc(vararg args: TypeRef) = Type.getMethodDescriptor(asm, *args.map { it.asm }.toTypedArray())
/** Size of this type on the stack; only 1 or 2 are allowed, where one slot is 32 bits. */
val stackSize: Int get() = if (asm == Type.DOUBLE_TYPE || asm == Type.LONG_TYPE) 2 else 1
fun asMethodRetDesc(vararg args: TypeRef) = Type.getMethodDescriptor(asm, *args.map { it.asm }.toTypedArray())
fun equivalentTo(other: TypeRef) = this == other || this == Unknown || other == Unknown
object UnknownType

View File

@ -5,6 +5,7 @@ import asmble.util.toRawIntBits
import asmble.util.toRawLongBits
import asmble.util.toUnsignedBigInt
import asmble.util.toUnsignedLong
import java.io.ByteArrayOutputStream
open class AstToBinary(val version: Long = 1L) {
@ -140,6 +141,9 @@ open class AstToBinary(val version: Long = 1L) {
fromResizableLimits(b, n.limits)
}
fun fromModule(n: Node.Module) =
ByteArrayOutputStream().also { fromModule(ByteWriter.OutputStream(it), n) }.toByteArray()
fun fromModule(b: ByteWriter, n: Node.Module) {
b.writeUInt32(0x6d736100)
b.writeUInt32(version)
@ -160,10 +164,33 @@ open class AstToBinary(val version: Long = 1L) {
wrapListSection(b, n, 9, n.elems, this::fromElem)
wrapListSection(b, n, 10, n.funcs, this::fromFuncBody)
wrapListSection(b, n, 11, n.data, this::fromData)
n.names?.also { fromNames(b, it) }
// All other custom sections after the previous
n.customSections.filter { it.afterSectionId > 11 }.forEach { fromCustomSection(b, it) }
}
fun fromNames(b: ByteWriter, n: Node.NameSection) {
fun <T> indexMap(b: ByteWriter, map: Map<Int, T>, fn: (T) -> Unit) {
b.writeVarUInt32(map.size)
map.forEach { index, v -> b.writeVarUInt32(index).also { fn(v) } }
}
fun nameMap(b: ByteWriter, map: Map<Int, String>) = indexMap(b, map) { b.writeString(it) }
b.writeVarUInt7(0)
b.withVarUInt32PayloadSizePrepended { b ->
b.writeString("name")
n.moduleName?.also { moduleName ->
b.writeVarUInt7(0)
b.withVarUInt32PayloadSizePrepended { b -> b.writeString(moduleName) }
}
if (n.funcNames.isNotEmpty()) b.writeVarUInt7(1).also {
b.withVarUInt32PayloadSizePrepended { b -> nameMap(b, n.funcNames) }
}
if (n.localNames.isNotEmpty()) b.writeVarUInt7(2).also {
b.withVarUInt32PayloadSizePrepended { b -> indexMap(b, n.localNames) { nameMap(b, it) } }
}
}
}
fun fromResizableLimits(b: ByteWriter, n: Node.ResizableLimits) {
b.writeVarUInt1(n.maximum != null)
b.writeVarUInt32(n.initial)

View File

@ -62,13 +62,21 @@ open class AstToSExpr(val parensInstrs: Boolean = true) {
Node.ExternalKind.GLOBAL -> newMulti("global") + v.index
}
fun fromFunc(v: Node.Func, name: String? = null, impExp: ImportOrExport? = null) =
newMulti("func", name) + impExp?.let(this::fromImportOrExport) + fromFuncSig(v.type) +
fromLocals(v.locals) + fromInstrs(v.instructions).unwrapInstrs()
fun fromFunc(
v: Node.Func,
name: String? = null,
impExp: ImportOrExport? = null,
localNames: Map<Int, String> = emptyMap()
) =
newMulti("func", name) + impExp?.let(this::fromImportOrExport) + fromFuncSig(v.type, localNames) +
fromLocals(v.locals, v.type.params.size, localNames) + fromInstrs(v.instructions).unwrapInstrs()
fun fromFuncSig(v: Node.Type.Func): List<SExpr> {
fun fromFuncSig(v: Node.Type.Func, localNames: Map<Int, String> = emptyMap()): List<SExpr> {
var ret = emptyList<SExpr>()
if (v.params.isNotEmpty()) ret += newMulti("param") + v.params.map(this::fromType)
if (v.params.isNotEmpty()) {
if (localNames.isEmpty()) ret += newMulti("param") + v.params.map(this::fromType)
else ret += v.params.mapIndexed { index, param -> newMulti("param", localNames[index]) + fromType(param) }
}
v.ret?.also { ret += newMulti("result") + fromType(it) }
return ret
}
@ -80,8 +88,8 @@ open class AstToSExpr(val parensInstrs: Boolean = true) {
fun fromGlobalSig(v: Node.Type.Global) =
if (v.mutable) newMulti("mut") + fromType(v.contentType) else fromType(v.contentType)
fun fromImport(v: Node.Import, types: List<Node.Type.Func>) =
(newMulti("import") + v.module.quoted) + v.field.quoted + fromImportKind(v.kind, types)
fun fromImport(v: Node.Import, types: List<Node.Type.Func>, name: String? = null) =
(newMulti("import") + v.module.quoted) + v.field.quoted + fromImportKind(v.kind, types, name)
fun fromImportFunc(v: Node.Import.Kind.Func, types: List<Node.Type.Func>, name: String? = null) =
fromImportFunc(types.getOrElse(v.typeIndex) { throw Exception("No type at ${v.typeIndex}") }, name)
@ -91,11 +99,11 @@ open class AstToSExpr(val parensInstrs: Boolean = true) {
fun fromImportGlobal(v: Node.Import.Kind.Global, name: String? = null) =
newMulti("global", name) + fromGlobalSig(v.type)
fun fromImportKind(v: Node.Import.Kind, types: List<Node.Type.Func>) = when(v) {
is Node.Import.Kind.Func -> fromImportFunc(v, types)
is Node.Import.Kind.Table -> fromImportTable(v)
is Node.Import.Kind.Memory -> fromImportMemory(v)
is Node.Import.Kind.Global -> fromImportGlobal(v)
fun fromImportKind(v: Node.Import.Kind, types: List<Node.Type.Func>, name: String? = null) = when(v) {
is Node.Import.Kind.Func -> fromImportFunc(v, types, name)
is Node.Import.Kind.Table -> fromImportTable(v, name)
is Node.Import.Kind.Memory -> fromImportMemory(v, name)
is Node.Import.Kind.Global -> fromImportGlobal(v, name)
}
fun fromImportMemory(v: Node.Import.Kind.Memory, name: String? = null) =
@ -161,8 +169,10 @@ open class AstToSExpr(val parensInstrs: Boolean = true) {
return listOf(SExpr.Multi(untilNext().first))
}
fun fromLocals(v: List<Node.Type.Value>) =
if (v.isEmpty()) null else newMulti("local") + v.map(this::fromType)
fun fromLocals(v: List<Node.Type.Value>, paramOffset: Int, localNames: Map<Int, String> = emptyMap()) =
if (v.isEmpty()) emptyList()
else if (localNames.isEmpty()) listOf(newMulti("local") + v.map(this::fromType))
else v.mapIndexed { index, v -> newMulti("local", localNames[paramOffset + index]) + fromType(v) }
fun fromMemory(v: Node.Type.Memory, name: String? = null, impExp: ImportOrExport? = null) =
newMulti("memory", name) + impExp?.let(this::fromImportOrExport) + fromMemorySig(v)
@ -175,7 +185,7 @@ open class AstToSExpr(val parensInstrs: Boolean = true) {
is Script.Cmd.Meta.Output -> newMulti("output", v.name) + v.str
}
fun fromModule(v: Node.Module, name: String? = null): SExpr.Multi {
fun fromModule(v: Node.Module, name: String? = v.names?.moduleName): SExpr.Multi {
var ret = newMulti("module", name)
// If there is a call_indirect, then we need to output all types in exact order.
@ -187,8 +197,14 @@ open class AstToSExpr(val parensInstrs: Boolean = true) {
v.types.filterIndexed { i, _ -> importIndices.contains(i) } - v.funcs.map { it.type }
}
// Keep track of the current function index for names
var funcIndex = -1
ret += types.map { fromTypeDef(it) }
ret += v.imports.map { fromImport(it, v.types) }
ret += v.imports.map {
if (it.kind is Node.Import.Kind.Func) funcIndex++
fromImport(it, v.types, v.names?.funcNames?.get(funcIndex))
}
ret += v.exports.map(this::fromExport)
ret += v.tables.map { fromTable(it) }
ret += v.memories.map { fromMemory(it) }
@ -196,7 +212,14 @@ open class AstToSExpr(val parensInstrs: Boolean = true) {
ret += v.elems.map(this::fromElem)
ret += v.data.map(this::fromData)
ret += v.startFuncIndex?.let(this::fromStart)
ret += v.funcs.map { fromFunc(it) }
ret += v.funcs.map {
funcIndex++
fromFunc(
v = it,
name = v.names?.funcNames?.get(funcIndex),
localNames = v.names?.localNames?.get(funcIndex) ?: emptyMap()
)
}
return ret
}
@ -235,10 +258,8 @@ open class AstToSExpr(val parensInstrs: Boolean = true) {
if (exp == null) this else this.copy(vals = this.vals + exp)
private operator fun SExpr.Multi.plus(exps: List<SExpr>?) =
if (exps == null || exps.isEmpty()) this else this.copy(vals = this.vals + exps)
private fun newMulti(initSymb: String? = null, initName: String? = null): SExpr.Multi {
initName?.also { require(it.startsWith("$")) }
return SExpr.Multi() + initSymb + initName
}
private fun newMulti(initSymb: String? = null, initName: String? = null) =
SExpr.Multi() + initSymb + initName?.let { "$$it" }
private fun List<SExpr.Multi>.unwrapInstrs() =
if (parensInstrs) this else this.single().vals
private val String.quoted get() = fromString(this, true)

View File

@ -6,7 +6,8 @@ import java.nio.ByteBuffer
open class BinaryToAst(
val version: Long = 1L,
val logger: Logger = Logger.Print(Logger.Level.OFF)
val logger: Logger = Logger.Print(Logger.Level.WARN),
val includeNameSection: Boolean = true
) : Logger by logger {
fun toBlockType(b: ByteReader) = b.readVarInt7().toInt().let {
@ -19,6 +20,23 @@ open class BinaryToAst(
payload = b.readBytes()
)
fun toNameSection(b: ByteReader) = generateSequence {
if (b.isEof) null
else b.readVarUInt7().toInt() to b.read(b.readVarUInt32AsInt())
}.fold(Node.NameSection(null, emptyMap(), emptyMap())) { sect, (type, b) ->
fun <T> indexMap(b: ByteReader, fn: (ByteReader) -> T) =
b.readList { it.readVarUInt32AsInt() to fn(it) }.let { pairs ->
pairs.toMap().also { require(it.size == pairs.size) { "Malformed names: duplicate indices" } }
}
fun nameMap(b: ByteReader) = indexMap(b) { it.readString() }
when (type) {
0 -> sect.copy(moduleName = b.readString())
1 -> sect.copy(funcNames = nameMap(b))
2 -> sect.copy(localNames = indexMap(b, ::nameMap))
else -> error("Malformed names: unrecognized type: $type")
}.also { require(b.isEof) }
}
fun toData(b: ByteReader) = Node.Data(
index = b.readVarUInt32AsInt(),
offset = toInitExpr(b),
@ -144,12 +162,15 @@ open class BinaryToAst(
}
}
fun toLocals(b: ByteReader) = b.readVarUInt32AsInt().let { size ->
toValueType(b).let { type -> List(size) { type } }
fun toLocals(b: ByteReader): List<Node.Type.Value> {
val size = try { b.readVarUInt32AsInt() } catch (e: NumberFormatException) { throw IoErr.InvalidLocalSize(e) }
return toValueType(b).let { type -> List(size) { type } }
}
fun toMemoryType(b: ByteReader) = Node.Type.Memory(toResizableLimits(b))
fun toModule(b: ByteArray) = toModule(ByteReader.InputStream(b.inputStream()))
fun toModule(b: ByteReader): Node.Module {
if (b.readUInt32() != 0x6d736100L) throw IoErr.InvalidMagicNumber()
b.readUInt32().let { if (it != version) throw IoErr.InvalidVersion(it, listOf(version)) }
@ -164,14 +185,17 @@ open class BinaryToAst(
require(sectionId > maxSectionId) { "Section ID $sectionId came after $maxSectionId" }.
also { maxSectionId = sectionId }
val sectionLen = b.readVarUInt32AsInt()
// each 'read' invocation creates a new InputStream and doesn't close it
sections += sectionId to b.read(sectionLen)
}
// Now build the module
fun <T> readSectionList(sectionId: Int, fn: (ByteReader) -> T) =
sections.find { it.first == sectionId }?.second?.readList(fn) ?: emptyList()
val types = readSectionList(1, this::toFuncType)
val funcIndices = readSectionList(3) { it.readVarUInt32AsInt() }
var nameSection: Node.NameSection? = null
return Node.Module(
types = types,
imports = readSectionList(2, this::toImport),
@ -192,10 +216,18 @@ open class BinaryToAst(
val afterSectionId = if (index == 0) 0 else sections[index - 1].let { (prevSectionId, _) ->
if (prevSectionId == 0) customSections.last().afterSectionId else prevSectionId
}
customSections + toCustomSection(b, afterSectionId)
// Try to parse the name section
val section = toCustomSection(b, afterSectionId).takeIf { section ->
val shouldParseNames = includeNameSection && nameSection == null && section.name == "name"
!shouldParseNames || try {
nameSection = toNameSection(ByteReader.InputStream(section.payload.inputStream()))
false
} catch (e: Exception) { warn { "Failed parsing name section: $e" }; true }
}
if (section == null) customSections else customSections + section
}
}
)
).copy(names = nameSection)
}
fun toResizableLimits(b: ByteReader) = b.readVarUInt1().let {

View File

@ -1,12 +1,12 @@
package asmble.io
import asmble.util.toIntExact
import asmble.util.toUnsignedBigInt
import asmble.util.toUnsignedLong
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.math.BigInteger
abstract class ByteReader {
abstract val isEof: Boolean
@ -34,27 +34,30 @@ abstract class ByteReader {
}
fun readVarInt7() = readSignedLeb128().let {
require(it >= Byte.MIN_VALUE.toLong() && it <= Byte.MAX_VALUE.toLong())
if (it < Byte.MIN_VALUE.toLong() || it > Byte.MAX_VALUE.toLong()) throw IoErr.InvalidLeb128Number()
it.toByte()
}
fun readVarInt32() = readSignedLeb128().toIntExact()
fun readVarInt32() = readSignedLeb128().let {
if (it < Int.MIN_VALUE.toLong() || it > Int.MAX_VALUE.toLong()) throw IoErr.InvalidLeb128Number()
it.toInt()
}
fun readVarInt64() = readSignedLeb128()
fun readVarInt64() = readSignedLeb128(9)
fun readVarUInt1() = readUnsignedLeb128().let {
require(it == 1 || it == 0)
if (it != 1 && it != 0) throw IoErr.InvalidLeb128Number()
it == 1
}
fun readVarUInt7() = readUnsignedLeb128().let {
require(it <= 255)
if (it > 255) throw IoErr.InvalidLeb128Number()
it.toShort()
}
fun readVarUInt32() = readUnsignedLeb128().toUnsignedLong()
protected fun readUnsignedLeb128(): Int {
protected fun readUnsignedLeb128(maxCount: Int = 4): Int {
// Taken from Android source, Apache licensed
var result = 0
var cur: Int
@ -63,12 +66,12 @@ abstract class ByteReader {
cur = readByte().toInt() and 0xff
result = result or ((cur and 0x7f) shl (count * 7))
count++
} while (cur and 0x80 == 0x80 && count < 5)
if (cur and 0x80 == 0x80) throw NumberFormatException()
} while (cur and 0x80 == 0x80 && count <= maxCount)
if (cur and 0x80 == 0x80) throw IoErr.InvalidLeb128Number()
return result
}
private fun readSignedLeb128(): Long {
private fun readSignedLeb128(maxCount: Int = 4): Long {
// Taken from Android source, Apache licensed
var result = 0L
var cur: Int
@ -79,12 +82,25 @@ abstract class ByteReader {
result = result or ((cur and 0x7f).toLong() shl (count * 7))
signBits = signBits shl 7
count++
} while (cur and 0x80 == 0x80 && count < 10)
if (cur and 0x80 == 0x80) throw NumberFormatException()
} while (cur and 0x80 == 0x80 && count <= maxCount)
if (cur and 0x80 == 0x80) throw IoErr.InvalidLeb128Number()
// Check for 64 bit invalid, taken from Apache/MIT licensed:
// https://github.com/paritytech/parity-wasm/blob/2650fc14c458c6a252c9dc43dd8e0b14b6d264ff/src/elements/primitives.rs#L351
// TODO: probably need 32 bit checks too, but meh, not in the suite
if (count > maxCount && maxCount == 9) {
if (cur and 0b0100_0000 == 0b0100_0000) {
if ((cur or 0b1000_0000).toByte() != (-1).toByte()) throw IoErr.InvalidLeb128Number()
} else if (cur != 0) {
throw IoErr.InvalidLeb128Number()
}
}
if ((signBits shr 1) and result != 0L) result = result or signBits
return result
}
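As an aside on the LEB128 readers above, the classic unsigned example is 624485, which encodes to the bytes 0xE5 0x8E 0x26; the standalone decoder below mirrors the same shift-and-accumulate loop (illustrative only, not part of this change):

fun decodeUnsignedLeb128(bytes: ByteArray): Int {
    var result = 0
    var count = 0
    var cur: Int
    do {
        // The low 7 bits carry the payload, the high bit marks continuation
        cur = bytes[count].toInt() and 0xff
        result = result or ((cur and 0x7f) shl (count * 7))
        count++
    } while (cur and 0x80 == 0x80)
    return result
}

fun main() {
    // 0xE5 0x8E 0x26 -> 0x65 + (0x0E shl 7) + (0x26 shl 14) = 624485
    println(decodeUnsignedLeb128(byteArrayOf(0xE5.toByte(), 0x8E.toByte(), 0x26)))
}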
// todo looks like this InputStream isn't ever closed
class InputStream(val ins: java.io.InputStream) : ByteReader() {
private var nextByte: Byte? = null
private var sawEof = false

View File

@ -1,7 +1,6 @@
package asmble.io
import asmble.AsmErr
import java.math.BigInteger
sealed class IoErr(message: String, cause: Throwable? = null) : RuntimeException(message, cause), AsmErr {
class UnexpectedEnd : IoErr("Unexpected EOF") {
@ -119,4 +118,13 @@ sealed class IoErr(message: String, cause: Throwable? = null) : RuntimeException
class InvalidUtf8Encoding : IoErr("Some byte sequence was not UTF-8 compatible") {
override val asmErrString get() = "invalid UTF-8 encoding"
}
class InvalidLeb128Number : IoErr("Invalid LEB128 number") {
override val asmErrString get() = "integer representation too long"
override val asmErrStrings get() = listOf(asmErrString, "integer too large")
}
class InvalidLocalSize(cause: NumberFormatException) : IoErr("Invalid local size", cause) {
override val asmErrString get() = "too many locals"
}
}

View File

@ -9,22 +9,27 @@ import asmble.util.*
import java.io.ByteArrayInputStream
import java.math.BigInteger
typealias NameMap = Map<String, Int>
open class SExprToAst {
open class SExprToAst(
val includeNames: Boolean = true
) {
data class ExprContext(
val nameMap: NameMap,
val blockDepth: Int = 0,
val types: List<Node.Type.Func> = emptyList(),
val callIndirectNeverBeforeSeenFuncTypes: MutableList<Node.Type.Func> = mutableListOf()
)
) {
companion object {
val empty = ExprContext(NameMap(emptyMap(), null, null))
}
}
data class FuncResult(
val name: String?,
val func: Node.Func,
val importOrExport: ImportOrExport?,
// These come from call_indirect insns
val additionalFuncTypesToAdd: List<Node.Type.Func>
val additionalFuncTypesToAdd: List<Node.Type.Func>,
val nameMap: NameMap
)
fun toAction(exp: SExpr.Multi): Script.Cmd.Action {
@ -36,7 +41,7 @@ open class SExprToAst {
return when(exp.vals.first().symbolStr()) {
"invoke" ->
Script.Cmd.Action.Invoke(name, str, exp.vals.drop(index).map {
toExprMaybe(it as SExpr.Multi, ExprContext(emptyMap()))
toExprMaybe(it as SExpr.Multi, ExprContext.empty)
})
"get" ->
Script.Cmd.Action.Get(name, str)
@ -49,7 +54,7 @@ open class SExprToAst {
return when(exp.vals.first().symbolStr()) {
"assert_return" ->
Script.Cmd.Assertion.Return(toAction(mult),
exp.vals.drop(2).map { toExprMaybe(it as SExpr.Multi, ExprContext(emptyMap())) })
exp.vals.drop(2).map { toExprMaybe(it as SExpr.Multi, ExprContext.empty) })
"assert_return_canonical_nan" ->
Script.Cmd.Assertion.ReturnNan(toAction(mult), canonical = true)
"assert_return_arithmetic_nan" ->
@ -176,7 +181,7 @@ open class SExprToAst {
var innerCtx = ctx.copy(blockDepth = ctx.blockDepth + 1)
exp.maybeName(opOffset)?.also {
opOffset++
innerCtx = innerCtx.copy(nameMap = innerCtx.nameMap + ("block:$it" to innerCtx.blockDepth))
innerCtx = innerCtx.copy(nameMap = innerCtx.nameMap.add("block", it, innerCtx.blockDepth))
}
val sigs = toBlockSigMaybe(exp, opOffset)
@ -233,7 +238,7 @@ open class SExprToAst {
var (nameMap, exprsUsed, sig) = toFuncSig(exp, currentIndex, origNameMap, types)
currentIndex += exprsUsed
val locals = exp.repeated("local", currentIndex, { toLocals(it) }).mapIndexed { index, (nameMaybe, vals) ->
nameMaybe?.also { require(vals.size == 1); nameMap += "local:$it" to (index + sig.params.size) }
nameMaybe?.also { require(vals.size == 1); nameMap = nameMap.add("local", it, index + sig.params.size) }
vals
}
currentIndex += locals.size
@ -250,7 +255,8 @@ open class SExprToAst {
name = name,
func = Node.Func(sig, locals.flatten(), instrs),
importOrExport = maybeImpExp,
additionalFuncTypesToAdd = ctx.callIndirectNeverBeforeSeenFuncTypes
additionalFuncTypesToAdd = ctx.callIndirectNeverBeforeSeenFuncTypes,
nameMap = nameMap
)
}
@ -268,7 +274,7 @@ open class SExprToAst {
} else null to offset
var nameMap = origNameMap
val params = exp.repeated("param", offset, { toParams(it) }).mapIndexed { index, (nameMaybe, vals) ->
nameMaybe?.also { require(vals.size == 1); nameMap += "local:$it" to index }
nameMaybe?.also { require(vals.size == 1); nameMap = nameMap.add("local", it, index) }
vals
}
val resultExps = exp.repeated("result", offset + params.size, this::toResult)
@ -395,7 +401,7 @@ open class SExprToAst {
val maybeName = exp.maybeName(offset + opOffset)
if (maybeName != null) {
opOffset++
innerCtx = innerCtx.copy(nameMap = innerCtx.nameMap + ("block:$maybeName" to innerCtx.blockDepth))
innerCtx = innerCtx.copy(nameMap = innerCtx.nameMap.add("block", maybeName, innerCtx.blockDepth))
}
val sigs = toBlockSigMaybe(exp, offset + opOffset)
opOffset += sigs.size
@ -428,7 +434,7 @@ open class SExprToAst {
opOffset++
exp.maybeName(offset + opOffset)?.also {
opOffset++
innerCtx = innerCtx.copy(nameMap = innerCtx.nameMap + ("block:$it" to ctx.blockDepth))
innerCtx = innerCtx.copy(nameMap = innerCtx.nameMap.add("block", it, ctx.blockDepth))
}
toInstrs(exp, offset + opOffset, innerCtx, false).also {
ret += it.first
@ -522,7 +528,7 @@ open class SExprToAst {
val exps = exp.vals.mapNotNull { it as? SExpr.Multi }
// Eagerly build the names (for forward decls)
val (nameMap, eagerTypes) = toModuleForwardNameMapAndTypes(exps)
var (nameMap, eagerTypes) = toModuleForwardNameMapAndTypes(exps)
mod = mod.copy(types = eagerTypes)
fun Node.Module.addTypeIfNotPresent(type: Node.Type.Func): Pair<Node.Module, Int> {
@ -555,6 +561,7 @@ open class SExprToAst {
Node.Import.Kind.Memory(kind) to (memoryCount++ to Node.ExternalKind.MEMORY)
else -> throw Exception("Unrecognized import kind: $kind")
}
mod = mod.copy(
imports = mod.imports + Node.Import(module, field, importKind),
exports = mod.exports + exportFields.map {
@ -579,11 +586,14 @@ open class SExprToAst {
"elem" -> mod = mod.copy(elems = mod.elems + toElem(exp, nameMap))
"data" -> mod = mod.copy(data = mod.data + toData(exp, nameMap))
"start" -> mod = mod.copy(startFuncIndex = toStart(exp, nameMap))
"func" -> toFunc(exp, nameMap, mod.types).also { (_, fn, impExp, additionalFuncTypes) ->
"func" -> toFunc(exp, nameMap, mod.types).also { (_, fn, impExp, additionalFuncTypes, localNameMap) ->
if (impExp is ImportOrExport.Import) {
handleImport(impExp.module, impExp.name, fn.type, impExp.exportFields)
} else {
if (impExp is ImportOrExport.Export) addExport(impExp, Node.ExternalKind.FUNCTION, funcCount)
if (includeNames) nameMap = nameMap.copy(
localNames = nameMap.localNames!! + (funcCount to localNameMap.getAllNamesByIndex("local"))
)
funcCount++
mod = mod.copy(funcs = mod.funcs + fn).addTypeIfNotPresent(fn.type).first
mod = additionalFuncTypes.fold(mod) { mod, typ -> mod.addTypeIfNotPresent(typ).first }
@ -644,6 +654,15 @@ open class SExprToAst {
if (mod.tables.size + mod.imports.count { it.kind is Node.Import.Kind.Table } > 1)
throw IoErr.MultipleTables()
// Set the name map pieces if we're including them
if (includeNames) mod = mod.copy(
names = Node.NameSection(
moduleName = name,
funcNames = nameMap.funcNames!!,
localNames = nameMap.localNames!!
)
)
return name to mod
}
@ -680,10 +699,14 @@ open class SExprToAst {
var globalCount = 0
var tableCount = 0
var memoryCount = 0
var namesToIndices = emptyMap<String, Int>()
var nameMap = NameMap(
names = emptyMap(),
funcNames = if (includeNames) emptyMap() else null,
localNames = if (includeNames) emptyMap() else null
)
var types = emptyList<Node.Type.Func>()
fun maybeAddName(name: String?, index: Int, type: String) {
name?.let { namesToIndices += "$type:$it" to index }
name?.also { nameMap = nameMap.add(type, it, index) }
}
// All imports first
@ -711,12 +734,12 @@ open class SExprToAst {
"memory" -> maybeAddName(kindName, memoryCount++, "memory")
// We go ahead and do the full type def build here eagerly
"type" -> maybeAddName(kindName, types.size, "type").also { _ ->
toTypeDef(it, namesToIndices).also { (_, type) -> types += type }
toTypeDef(it, nameMap).also { (_, type) -> types += type }
}
else -> {}
}
}
return namesToIndices to types
return nameMap to types
}
fun toOpMaybe(exp: SExpr.Multi, offset: Int, ctx: ExprContext): Pair<Node.Instr, Int>? {
@ -753,7 +776,8 @@ open class SExprToAst {
// First lookup the func sig
val (updatedNameMap, expsUsed, funcType) = toFuncSig(exp, offset + 1, ctx.nameMap, ctx.types)
// Make sure there are no changes to the name map
if (ctx.nameMap.size != updatedNameMap.size) throw IoErr.IndirectCallSetParamNames()
if (ctx.nameMap.size != updatedNameMap.size)
throw IoErr.IndirectCallSetParamNames()
// Obtain the func index from the types table, the indirects table, or just add it
var funcTypeIndex = ctx.types.indexOf(funcType)
// If it's not in the type list, check the call indirect list
@ -910,7 +934,7 @@ open class SExprToAst {
fun toVarMaybe(exp: SExpr, nameMap: NameMap, nameType: String): Int? {
return exp.symbolStr()?.let { it ->
if (it.startsWith("$"))
nameMap["$nameType:$it"] ?:
nameMap.get(nameType, it.drop(1)) ?:
throw Exception("Unable to find index for name $it of type $nameType in $nameMap")
else if (it.startsWith("0x")) it.substring(2).toIntOrNull(16)
else it.toIntOrNull()
@ -1005,7 +1029,7 @@ open class SExprToAst {
private fun SExpr.Multi.maybeName(index: Int): String? {
if (this.vals.size > index && this.vals[index] is SExpr.Symbol) {
val sym = this.vals[index] as SExpr.Symbol
if (!sym.quoted && sym.contents[0] == '$') return sym.contents
if (!sym.quoted && sym.contents[0] == '$') return sym.contents.drop(1)
}
return null
}
@ -1028,5 +1052,26 @@ open class SExprToAst {
return this.vals.first().requireSymbol(contents, quotedCheck)
}
data class NameMap(
// Keys are prefixed with the type, then a colon, then the actual name
val names: Map<String, Int>,
// Null if not including names
val funcNames: Map<Int, String>?,
val localNames: Map<Int, Map<Int, String>>?
) {
val size get() = names.size
fun add(type: String, name: String, index: Int) = copy(
names = names + ("$type:$name" to index),
funcNames = funcNames?.let { if (type == "func") it + (index to name) else it }
)
fun get(type: String, name: String) = names["$type:$name"]
fun getAllNamesByIndex(type: String) = names.mapNotNull { (k, v) ->
k.takeIf { k.startsWith("$type:") }?.let { v to k.substring(type.length + 1) }
}.toMap()
}
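A short illustrative sketch (not from this change, names are hypothetical) of how the NameMap above keys entries as "type:name" while also accumulating the index-to-name maps later emitted as the name section:

fun nameMapDemo() {
    val nameMap = SExprToAst.NameMap(names = emptyMap(), funcNames = emptyMap(), localNames = emptyMap())
        .add("func", "some_func", 0)    // stored under "func:some_func" and recorded in funcNames
        .add("local", "func_param", 0)  // stored under "local:func_param"
    println(nameMap.get("func", "some_func"))     // 0
    println(nameMap.getAllNamesByIndex("local"))  // {0=func_param}
    println(nameMap.funcNames)                    // {0=some_func}
}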
companion object : SExprToAst()
}

View File

@ -18,6 +18,13 @@ open class StrToSExpr {
data class Error(val pos: Pos, val msg: String) : ParseResult()
}
fun parseSingleMulti(str: CharSequence) = parse(str).let {
when (it) {
is ParseResult.Success -> (it.vals.singleOrNull() as? SExpr.Multi) ?: error("Not a single multi-expr")
is ParseResult.Error -> error("Failed parsing at ${it.pos.line}:${it.pos.char} - ${it.msg}")
}
}
fun parse(str: CharSequence): ParseResult {
val state = ParseState(str)
val ret = mutableListOf<SExpr>()

View File

@ -0,0 +1,56 @@
package asmble.run.jvm
/**
* Used to track the state of the environment module.
*/
data class EnvState(
var spentGas: Long = 0,
// executed instruction counter
var EIC: Long = 0
)
/**
* Module used for gas and EIC metering.
*/
open class EnvModule(private val gasLimit: Long) {
private var state = EnvState();
/**
* [Wasm function]
* Adds the spent gas to the overall spent gas and checks whether the limit is exceeded.
*/
fun gas(spentGas: Int) {
if(state.spentGas + spentGas > gasLimit) {
// TODO : check for overflow, throw an exception
}
state.spentGas += spentGas;
}
/**
* [Wasm function]
* Adds the given EIC to the overall executed instruction counter.
*/
fun eic(EIC: Int) {
state.EIC += EIC;
}
/**
* Sets the spent gas and EIC values to 0. Used by WasmVm to clear the gas value before metering.
* It should be impossible to call this function from a Wasm module.
*/
fun clearState() {
state.spentGas = 0;
state.EIC = 0;
}
/**
* Returns the environment module state.
* Used by WasmVm to determine the spent gas and executed instruction counter after each invocation.
*/
fun getState(): EnvState {
return state;
}
}
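A hedged sketch of how the metering module above might be wired up; withModuleRegistered and Module.Native appear in the ScriptContext changes further down this diff, while the module name "env", the gas limit, and the default construction of ScriptContext are assumptions:

fun meteringDemo() {
    val envModule = EnvModule(gasLimit = 10_000_000L)
    // "env" is an assumed registration name for the metering imports
    val ctx = ScriptContext(packageName = "test")
        .withModuleRegistered("env", Module.Native(envModule))
    // ... compile and run a module against ctx, then inspect and reset the counters ...
    val state = envModule.getState()
    println("spent gas: ${state.spentGas}, EIC: ${state.EIC}")
    envModule.clearState()
}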

View File

@ -0,0 +1,45 @@
package asmble.run.jvm
import asmble.compile.jvm.Mem
import java.io.PrintWriter
import java.nio.ByteBuffer
/**
* Module used for logging UTF-8 strings from a Wasm module to a given writer.
*/
open class LoggerModule(val writer: PrintWriter) {
// one memory page is quite enough for the temporary buffer
private val memoryPages = 1
private val memory =
ByteBuffer.allocate(memoryPages * Mem.PAGE_SIZE) as ByteBuffer
/**
* [Wasm function]
* Writes one byte to the logger memory buffer. If there is no space left, flushes
* all data from the buffer to the [PrintWriter] and then puts the byte.
*/
fun write(byte: Int) {
val isFull = memory.position() >= memory.limit()
if (isFull) {
flush()
}
memory.put(byte.toByte())
}
/**
* [Wasm function]
* Reads all bytes from the logger memory buffer, converts them to a UTF-8
* string and writes it to the writer.
* Then clears the logger memory buffer.
*/
fun flush() {
val message = String(memory.array(), 0, memory.position())
writer.print(message)
writer.flush()
memory.clear()
}
}

View File

@ -76,6 +76,8 @@ interface Module {
// If there is a memory import, we have to get the one with the mem class as the first
val memImport = mod.imports.find { it.kind is Node.Import.Kind.Memory }
val builder = ctx.memoryBuilder
val memLimit = if (memImport != null) {
constructor = cls.declaredConstructors.find { it.parameterTypes.firstOrNull()?.ref == mem.memType }
val memImportKind = memImport.kind as Node.Import.Kind.Memory
@ -89,6 +91,13 @@ interface Module {
throw RunErr.ImportMemoryCapacityTooLarge(it * Mem.PAGE_SIZE, memCap)
}
memLimit
} else if (builder != null) {
constructor = cls.declaredConstructors.find { it.parameterTypes.firstOrNull()?.ref == mem.memType }
val memLimit = ctx.defaultMaxMemPages * Mem.PAGE_SIZE
val memInst = builder.build(memLimit)
constructorParams += memInst
memLimit
} else {
// Find the constructor with no max mem amount (i.e. not int and not memory)
constructor = cls.declaredConstructors.find {
@ -116,9 +125,9 @@ interface Module {
}
// Global imports
val globalImports = mod.imports.mapNotNull {
if (it.kind is Node.Import.Kind.Global) ctx.resolveImportGlobal(it, it.kind.type)
else null
val globalImports = mod.imports.flatMap {
if (it.kind is Node.Import.Kind.Global) ctx.resolveImportGlobals(it, it.kind.type)
else emptyList()
}
constructorParams += globalImports

View File

@ -55,4 +55,12 @@ sealed class RunErr(message: String, cause: Throwable? = null) : RuntimeExceptio
override val asmErrString get() = "unknown import"
override val asmErrStrings get() = listOf(asmErrString, "incompatible import type")
}
class ImportGlobalInvalidMutability(
val module: String,
val field: String,
val expected: Boolean
) : RunErr("Expected imported global $module::$field to have mutability as ${!expected}") {
override val asmErrString get() = "incompatible import type"
}
}

View File

@ -19,6 +19,20 @@ import java.lang.invoke.MethodType
import java.lang.reflect.InvocationTargetException
import java.util.*
/**
* Script context. Contains all the information needed to execute this script.
*
* @param packageName Package name for this script
* @param modules List of all modules to be loaded for this script
* @param registrations Registered modules, keyed by module name with the module instance as the value
* @param logger Logger for this script
* @param adjustContext Function for tuning the context (looks redundant)
* @param classLoader ClassLoader for loading all classes for this script
* @param exceptionTranslator Converts exceptions to error messages
* @param defaultMaxMemPages The maximum number of memory pages when a module doesn't specify one
* @param includeBinaryInCompiledClass Whether to store the binary Wasm code in the compiled class
* file as the [asmble.annotation.WasmModule] annotation
*/
data class ScriptContext(
val packageName: String,
val modules: List<Module.Compiled> = emptyList(),
@ -29,8 +43,10 @@ data class ScriptContext(
ScriptContext.SimpleClassLoader(ScriptContext::class.java.classLoader, logger),
val exceptionTranslator: ExceptionTranslator = ExceptionTranslator,
val defaultMaxMemPages: Int = 1,
val includeBinaryInCompiledClass: Boolean = false
val includeBinaryInCompiledClass: Boolean = false,
val memoryBuilder: MemoryBufferBuilder? = null
) : Logger by logger {
fun withHarnessRegistered(out: PrintWriter = PrintWriter(System.out, true)) =
withModuleRegistered("spectest", Module.Native(TestHarness(out)))
@ -263,10 +279,12 @@ data class ScriptContext(
return Module.Compiled(mod, classLoader.fromBuiltContext(ctx), name, ctx.mem)
}
fun bindImport(import: Node.Import, getter: Boolean, methodType: MethodType): MethodHandle {
fun bindImport(import: Node.Import, getter: Boolean, methodType: MethodType) = bindImport(
import, if (getter) "get" + import.field.javaIdent.capitalize() else import.field.javaIdent, methodType)
fun bindImport(import: Node.Import, javaName: String, methodType: MethodType): MethodHandle {
// Find a method that matches our expectations
val module = registrations[import.module] ?: throw RunErr.ImportNotFound(import.module, import.field)
val javaName = if (getter) "get" + import.field.javaIdent.capitalize() else import.field.javaIdent
val kind = when (import.kind) {
is Node.Import.Kind.Func -> WasmExternalKind.FUNCTION
is Node.Import.Kind.Table -> WasmExternalKind.TABLE
@ -281,8 +299,18 @@ data class ScriptContext(
bindImport(import, false,
MethodType.methodType(funcType.ret?.jclass ?: Void.TYPE, funcType.params.map { it.jclass }))
fun resolveImportGlobal(import: Node.Import, globalType: Node.Type.Global) =
bindImport(import, true, MethodType.methodType(globalType.contentType.jclass))
fun resolveImportGlobals(import: Node.Import, globalType: Node.Type.Global): List<MethodHandle> {
val getter = bindImport(import, true, MethodType.methodType(globalType.contentType.jclass))
// Whether the setter is present or not defines whether it is mutable
val setter = try {
bindImport(import, "set" + import.field.javaIdent.capitalize(),
MethodType.methodType(Void.TYPE, globalType.contentType.jclass))
} catch (e: RunErr.ImportNotFound) { null }
// Mutability must match
if (globalType.mutable == (setter == null))
throw RunErr.ImportGlobalInvalidMutability(import.module, import.field, globalType.mutable)
return if (setter == null) listOf(getter) else listOf(getter, setter)
}
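To make the getter/setter naming convention concrete, a hypothetical host object backing imported globals might look like the sketch below (names are illustrative, not from this change); an immutable global needs only the get method, while a mutable one must also expose the matching set method or ImportGlobalInvalidMutability is thrown:

// Hypothetical host object, registered e.g. as Module.Native(HostGlobals()) under an assumed module name "host"
class HostGlobals {
    private var counter = 0
    // Backs (import "host" "counter" (global (mut i32))): getter plus setter marks it mutable
    fun getCounter(): Int = counter
    fun setCounter(v: Int) { counter = v }
    // Backs (import "host" "limit" (global i32)): getter only, so it is treated as immutable
    fun getLimit(): Int = 1024
}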
fun resolveImportMemory(import: Node.Import, memoryType: Node.Type.Memory, mem: Mem) =
bindImport(import, true, MethodType.methodType(Class.forName(mem.memType.asm.className))).
@ -293,10 +321,15 @@ data class ScriptContext(
bindImport(import, true, MethodType.methodType(Array<MethodHandle>::class.java)).
invokeWithArguments()!! as Array<MethodHandle>
open class SimpleClassLoader(parent: ClassLoader, logger: Logger) : ClassLoader(parent), Logger by logger {
open class SimpleClassLoader(
parent: ClassLoader,
logger: Logger,
val splitWhenTooLarge: Boolean = true
) : ClassLoader(parent), Logger by logger {
fun fromBuiltContext(ctx: ClsContext): Class<*> {
trace { "Computing frames for ASM class:\n" + ctx.cls.toAsmString() }
return ctx.cls.withComputedFramesAndMaxs().let { bytes ->
val writer = if (splitWhenTooLarge) AsmToBinary else AsmToBinary.noSplit
return writer.fromClassNode(ctx.cls).let { bytes ->
debug { "ASM class:\n" + bytes.asClassNode().toAsmString() }
defineClass("${ctx.packageName}.${ctx.className}", bytes, 0, bytes.size)
}
@ -313,4 +346,4 @@ data class ScriptContext(
defineClass(className, bytes, 0, bytes.size)
}
}
}
}

View File

@ -3,6 +3,8 @@ package asmble.run.jvm
import asmble.annotation.WasmExport
import asmble.annotation.WasmExternalKind
import asmble.compile.jvm.Mem
import asmble.compile.jvm.MemoryBuffer
import asmble.compile.jvm.MemoryByteBuffer
import java.io.PrintWriter
import java.lang.invoke.MethodHandle
import java.nio.ByteBuffer
@ -17,10 +19,10 @@ open class TestHarness(val out: PrintWriter) {
val global_f32 = 666.6f
val global_f64 = 666.6
val table = arrayOfNulls<MethodHandle>(10)
val memory = ByteBuffer.
val memory = MemoryByteBuffer(ByteBuffer.
allocateDirect(2 * Mem.PAGE_SIZE).
order(ByteOrder.LITTLE_ENDIAN).
limit(Mem.PAGE_SIZE) as ByteBuffer
limit(Mem.PAGE_SIZE) as ByteBuffer) as MemoryBuffer
// Note, we have all of these overloads because my import method
// resolver is simple right now and only finds exact methods via

View File

@ -13,10 +13,10 @@ class SpecTestUnit(name: String, wast: String, expectedOutput: String?) : BaseTe
override val shouldFail get() = name.endsWith(".fail")
override val defaultMaxMemPages get() = when (name) {
"nop"-> 20
"resizing" -> 830
"nop" -> 20
"memory_grow" -> 830
"imports" -> 5
else -> 1
else -> 2
}
override fun warningInsteadOfErrReason(t: Throwable) = when (name) {

View File

@ -0,0 +1,33 @@
package asmble.ast
import asmble.SpecTestUnit
import asmble.TestBase
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
@RunWith(Parameterized::class)
class StackTest(val unit: SpecTestUnit) : TestBase() {
@Test
fun testStack() {
// If it's not a module expecting an error, we'll try to walk the stack on each function
unit.script.commands.mapNotNull { it as? Script.Cmd.Module }.forEach { mod ->
mod.module.funcs.filter { it.instructions.isNotEmpty() }.forEach { func ->
debug { "Func: ${func.type}" }
var indexCount = 0
Stack.walkStrict(mod.module, func) { stack, insn ->
debug { " After $insn (next: ${func.instructions.getOrNull(++indexCount)}, " +
"unreach depth: ${stack.unreachableUntilNextEndCount})" }
debug { " " + stack.current }
}
}
}
}
companion object {
// Only tests that shouldn't fail
@JvmStatic @Parameterized.Parameters(name = "{0}")
fun data() = SpecTestUnit.allUnits.filterNot { it.shouldFail }//.filter { it.name == "loop" }
}
}

View File

@ -0,0 +1,44 @@
package asmble.compile.jvm
import asmble.TestBase
import asmble.ast.Node
import asmble.run.jvm.ScriptContext
import asmble.util.get
import org.junit.Test
import java.util.*
import kotlin.test.assertEquals
class LargeDataTest : TestBase() {
@Test
fun testLargeData() {
// This previously failed because string constants can't be longer than 65536 chars.
// We create a byte array across the whole gamut of byte values to test UTF-8 encoding.
val bytesExpected = ByteArray(70000) { ((it % 255) - Byte.MIN_VALUE).toByte() }
val mod = Node.Module(
memories = listOf(Node.Type.Memory(
limits = Node.ResizableLimits(initial = 2, maximum = 2)
)),
data = listOf(Node.Data(
index = 0,
offset = listOf(Node.Instr.I32Const(0)),
data = bytesExpected
))
)
val ctx = ClsContext(
packageName = "test",
className = "Temp" + UUID.randomUUID().toString().replace("-", ""),
mod = mod,
logger = logger
)
AstToAsm.fromModule(ctx)
val cls = ScriptContext.SimpleClassLoader(javaClass.classLoader, logger).fromBuiltContext(ctx)
// Instantiate it, get the memory out, and check it
val field = cls.getDeclaredField("memory").apply { isAccessible = true }
val buf = field[cls.newInstance()] as MemoryByteBuffer
// Grab all + 1 and check values
val bytesActual = ByteArray(70001).also { buf.bb.get(0, it) }
bytesActual.forEachIndexed { index, byte ->
assertEquals(if (index == 70000) 0.toByte() else bytesExpected[index], byte)
}
}
}

View File

@ -0,0 +1,38 @@
package asmble.compile.jvm
import asmble.TestBase
import asmble.io.SExprToAst
import asmble.io.StrToSExpr
import asmble.run.jvm.ScriptContext
import org.junit.Test
import java.util.*
class NamesTest : TestBase() {
@Test
fun testNames() {
// Compile and make sure the names are set right
val (_, mod) = SExprToAst.toModule(StrToSExpr.parseSingleMulti("""
(module ${'$'}mod_name
(import "foo" "bar" (func ${'$'}import_func (param i32)))
(type ${'$'}some_sig (func (param ${'$'}type_param i32)))
(func ${'$'}some_func
(type ${'$'}some_sig)
(param ${'$'}func_param i32)
(local ${'$'}func_local0 i32)
(local ${'$'}func_local1 f64)
)
)
""".trimIndent()))
val ctx = ClsContext(
packageName = "test",
className = "Temp" + UUID.randomUUID().toString().replace("-", ""),
mod = mod,
logger = logger
)
AstToAsm.fromModule(ctx)
val cls = ScriptContext.SimpleClassLoader(javaClass.classLoader, logger).fromBuiltContext(ctx)
// Make sure the import field and the func are present named
cls.getDeclaredField("import_func")
cls.getDeclaredMethod("some_func", Integer.TYPE)
}
}

View File

@ -1,6 +1,7 @@
package asmble.io
import asmble.SpecTestUnit
import asmble.TestBase
import asmble.ast.Node
import asmble.ast.Script
import asmble.util.Logger
@ -13,12 +14,10 @@ import java.io.ByteArrayOutputStream
import kotlin.test.assertEquals
@RunWith(Parameterized::class)
class IoTest(val unit: SpecTestUnit) : Logger by Logger.Print(Logger.Level.INFO) {
class IoTest(val unit: SpecTestUnit) : TestBase() {
@Test
fun testIo() {
// Ignore things that are supposed to fail
if (unit.shouldFail) return
// Go from the AST to binary then back to AST then back to binary and confirm values are as expected
val ast1 = unit.script.commands.mapNotNull { (it as? Script.Cmd.Module)?.module?.also {
trace { "AST from script:\n" + SExprToStr.fromSExpr(AstToSExpr.fromModule(it)) }
@ -46,7 +45,8 @@ class IoTest(val unit: SpecTestUnit) : Logger by Logger.Print(Logger.Level.INFO)
}
companion object {
// Only tests that shouldn't fail
@JvmStatic @Parameterized.Parameters(name = "{0}")
fun data() = SpecTestUnit.allUnits
fun data() = SpecTestUnit.allUnits.filterNot { it.shouldFail }
}
}

View File

@ -0,0 +1,47 @@
package asmble.io
import asmble.ast.Node
import org.junit.Test
import kotlin.test.assertEquals
class NamesTest {
@Test
fun testNames() {
// First, make sure it can parse from sexpr
val (_, mod1) = SExprToAst.toModule(StrToSExpr.parseSingleMulti("""
(module ${'$'}mod_name
(import "foo" "bar" (func ${'$'}import_func (param i32)))
(type ${'$'}some_sig (func (param ${'$'}type_param i32)))
(func ${'$'}some_func
(type ${'$'}some_sig)
(param ${'$'}func_param i32)
(local ${'$'}func_local0 i32)
(local ${'$'}func_local1 f64)
)
)
""".trimIndent()))
val expected = Node.NameSection(
moduleName = "mod_name",
funcNames = mapOf(
0 to "import_func",
1 to "some_func"
),
localNames = mapOf(
1 to mapOf(
0 to "func_param",
1 to "func_local0",
2 to "func_local1"
)
)
)
assertEquals(expected, mod1.names)
// Now back to binary and then back and make sure it's still there
val bytes = AstToBinary.fromModule(mod1)
val mod2 = BinaryToAst.toModule(bytes)
assertEquals(expected, mod2.names)
// Now back to sexpr and then back to make sure the sexpr writer works
val sexpr = AstToSExpr.fromModule(mod2)
val (_, mod3) = SExprToAst.toModule(sexpr)
assertEquals(expected, mod3.names)
}
}

View File

@ -0,0 +1,69 @@
package asmble.run.jvm
import asmble.TestBase
import asmble.ast.Node
import asmble.compile.jvm.AstToAsm
import asmble.compile.jvm.ClsContext
import asmble.compile.jvm.MemoryByteBuffer
import org.junit.Assert
import org.junit.Test
import org.objectweb.asm.MethodTooLargeException
import java.util.*
import kotlin.test.assertEquals
class LargeFuncTest : TestBase() {
@Test
fun testLargeFunc() {
val numInsnChunks = 6001
// Make large func that does some math
val ctx = ClsContext(
packageName = "test",
className = "Temp" + UUID.randomUUID().toString().replace("-", ""),
logger = logger,
mod = Node.Module(
memories = listOf(Node.Type.Memory(Node.ResizableLimits(initial = 4, maximum = 4))),
funcs = listOf(Node.Func(
type = Node.Type.Func(params = emptyList(), ret = null),
locals = emptyList(),
instructions = (0 until numInsnChunks).flatMap {
listOf<Node.Instr>(
Node.Instr.I32Const(it * 4),
// Let's do it * (it - 1)
Node.Instr.I32Const(it),
Node.Instr.I32Const(it - 1),
Node.Instr.I32Mul,
Node.Instr.I32Store(0, 0)
)
}
)),
names = Node.NameSection(
moduleName = null,
funcNames = mapOf(0 to "someFunc"),
localNames = emptyMap()
),
exports = listOf(
Node.Export("memory", Node.ExternalKind.MEMORY, 0),
Node.Export("someFunc", Node.ExternalKind.FUNCTION, 0)
)
)
)
// Compile it
AstToAsm.fromModule(ctx)
// Confirm the method size is too large
try {
ScriptContext.SimpleClassLoader(javaClass.classLoader, logger, splitWhenTooLarge = false).
fromBuiltContext(ctx)
Assert.fail()
} catch (e: MethodTooLargeException) { }
// Try again with split
val cls = ScriptContext.SimpleClassLoader(javaClass.classLoader, logger).fromBuiltContext(ctx)
// Create it and check that it still does what we expect
val inst = cls.newInstance()
// Run someFunc
cls.getMethod("someFunc").invoke(inst)
// Get the memory out
val mem = cls.getMethod("getMemory").invoke(inst) as MemoryByteBuffer
// Read out the mem values
(0 until numInsnChunks).forEach { assertEquals(it * (it - 1), mem.getInt(it * 4)) }
}
}

View File

@ -0,0 +1,51 @@
package asmble.run.jvm
import asmble.TestBase
import org.junit.Test
import java.io.PrintWriter
import java.io.StringWriter
import kotlin.test.assertEquals
class LoggerModuleTest : TestBase() {
@Test
fun writeAndFlushTest() {
val stream = StringWriter()
val logger = LoggerModule(PrintWriter(stream))
logger.flush() // checks that no raise error
val testString = "test String for log to stdout"
for (byte: Byte in testString.toByteArray()) {
logger.write(byte.toInt())
}
logger.flush()
val loggedString = stream.toString()
assertEquals(testString, loggedString)
}
@Test
fun writeAndFlushMoreThanLoggerBufferTest() {
val stream = StringWriter()
// the logger buffer is one memory page (64 KiB)
val logger = LoggerModule(PrintWriter(stream))
val testString = longString(65_000 * 2) // twice as much as logger buffer
for (byte: Byte in testString.toByteArray()) {
logger.write(byte.toInt())
}
logger.flush()
val loggedString = stream.toString()
assertEquals(testString, loggedString)
}
private fun longString(size: Int): String {
val stringBuffer = StringBuffer()
for (idx: Int in (1 until size)) {
stringBuffer.append((idx % Byte.MAX_VALUE).toChar())
}
return stringBuffer.toString()
}
}

View File

@ -3,6 +3,8 @@ package asmble.run.jvm
import asmble.BaseTestUnit
import asmble.TestBase
import asmble.annotation.WasmModule
import asmble.compile.jvm.MemoryBufferBuilder
import asmble.compile.jvm.MemoryByteBuffer
import asmble.io.AstToBinary
import asmble.io.AstToSExpr
import asmble.io.ByteWriter
@ -12,6 +14,7 @@ import org.junit.Test
import java.io.ByteArrayOutputStream
import java.io.OutputStreamWriter
import java.io.PrintWriter
import java.nio.ByteBuffer
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
@ -40,7 +43,10 @@ abstract class TestRunner<out T : BaseTestUnit>(val unit: T) : TestBase() {
adjustContext = { it.copy(eagerFailLargeMemOffset = false) },
defaultMaxMemPages = unit.defaultMaxMemPages,
// Include the binary data so we can check it later
includeBinaryInCompiledClass = true
includeBinaryInCompiledClass = true,
memoryBuilder = MemoryBufferBuilder { it ->
MemoryByteBuffer(ByteBuffer.allocateDirect(it))
}
).withHarnessRegistered(PrintWriter(OutputStreamWriter(out, Charsets.UTF_8), true))
// This will fail assertions as necessary

View File

@ -9,9 +9,3 @@ Compile Rust to WASM and then to the JVM. In order of complexity:
* [rust-simple](rust-simple)
* [rust-string](rust-string)
* [rust-regex](rust-regex)
### C/C++
Compile C to WASM and then to the JVM. In order of complexity:
* [c-simple](c-simple)

View File

@ -0,0 +1,9 @@
package main
import (
"fmt"
)
func main() {
fmt.Println("Hello, World!")
}

View File

@ -2,8 +2,9 @@
extern crate regex;
use std::ptr::NonNull;
use regex::Regex;
use std::heap::{Alloc, Heap, Layout};
use std::alloc::{Alloc, Global, Layout};
use std::mem;
use std::str;
@ -37,17 +38,17 @@ pub extern "C" fn match_count(r: *mut Regex, str_ptr: *mut u8, len: usize) -> us
}
#[no_mangle]
pub extern "C" fn alloc(size: usize) -> *mut u8 {
pub extern "C" fn alloc(size: usize) -> NonNull<u8> {
unsafe {
let layout = Layout::from_size_align(size, mem::align_of::<u8>()).unwrap();
Heap.alloc(layout).unwrap()
Global.alloc(layout).unwrap()
}
}
#[no_mangle]
pub extern "C" fn dealloc(ptr: *mut u8, size: usize) {
pub extern "C" fn dealloc(ptr: NonNull<u8>, size: usize) {
unsafe {
let layout = Layout::from_size_align(size, mem::align_of::<u8>()).unwrap();
Heap.dealloc(ptr, layout);
Global.dealloc(ptr, layout);
}
}

View File

@ -3,6 +3,9 @@ package asmble.examples.rustregex;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
* Implementation of {@link RegexLib} based on `java.util.regex`.
*/
public class JavaLib implements RegexLib<String> {
@Override
public JavaPattern compile(String str) {

View File

@ -5,6 +5,10 @@ import asmble.generated.RustRegex;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
/**
* Implementation of {@link RegexLib} based on `asmble.generated.RustRegex`, which
* was compiled from Rust code (see lib.rs).
*/
public class RustLib implements RegexLib<RustLib.Ptr> {
// 600 pages is enough for our use

View File

@ -1,6 +1,7 @@
#![feature(allocator_api)]
use std::heap::{Alloc, Heap, Layout};
use std::ptr::NonNull;
use std::alloc::{Alloc, Global, Layout};
use std::ffi::{CString};
use std::mem;
use std::os::raw::c_char;
@ -30,17 +31,17 @@ pub extern "C" fn prepend_from_rust(ptr: *mut u8, len: usize) -> *const c_char {
}
#[no_mangle]
pub extern "C" fn alloc(size: usize) -> *mut u8 {
pub extern "C" fn alloc(size: usize) -> NonNull<u8> {
unsafe {
let layout = Layout::from_size_align(size, mem::align_of::<u8>()).unwrap();
Heap.alloc(layout).unwrap()
Global.alloc(layout).unwrap()
}
}
#[no_mangle]
pub extern "C" fn dealloc(ptr: *mut u8, size: usize) {
pub extern "C" fn dealloc(ptr: NonNull<u8>, size: usize) {
unsafe {
let layout = Layout::from_size_align(size, mem::align_of::<u8>()).unwrap();
Heap.dealloc(ptr, layout);
Global.dealloc(ptr, layout);
}
}

View File

@ -2,6 +2,7 @@ rootProject.name = 'asmble'
include 'annotations',
'compiler',
'examples:c-simple',
'examples:go-simple',
'examples:rust-regex',
'examples:rust-simple',
'examples:rust-string'