diff --git a/src/main/kotlin/asmble/ast/Node.kt b/src/main/kotlin/asmble/ast/Node.kt index 971c092..dbff13a 100644 --- a/src/main/kotlin/asmble/ast/Node.kt +++ b/src/main/kotlin/asmble/ast/Node.kt @@ -386,6 +386,9 @@ sealed class Node { var strToOpMap = emptyMap>(); private set var classToOpMap = emptyMap, InstrOp<*>>(); private set var strToOpcodeMap = emptyMap(); private set + var opcodeToStrMap = emptyMap(); private set + + fun op(opcode: Short) = opcodeToStrMap[opcode]?.let(strToOpMap::get) ?: error("No opcode found: $opcode") init { // Can't use reification here because inline funcs not allowed in nested context :-( @@ -402,6 +405,7 @@ sealed class Node { strToOpMap += name to op classToOpMap += clazz to op strToOpcodeMap += name to opcode + opcodeToStrMap += opcode to name } opMapEntry("unreachable", 0x00, ::ControlFlowOpNoArg, Instr.Unreachable, Instr.Unreachable::class) diff --git a/src/main/kotlin/asmble/io/BinaryToAst.kt b/src/main/kotlin/asmble/io/BinaryToAst.kt index a4ccfd5..6294671 100644 --- a/src/main/kotlin/asmble/io/BinaryToAst.kt +++ b/src/main/kotlin/asmble/io/BinaryToAst.kt @@ -1,14 +1,33 @@ package asmble.io import asmble.ast.Node +import asmble.util.fromIntBits +import asmble.util.fromLongBits import asmble.util.toIntExact +import asmble.util.toUnsignedShort open class BinaryToAst(val version: Long = 0xd) { + fun toBlockType(b: ByteReader) = b.readVarInt7("block_type").toInt().let { + if (it == -0x40) null else toValueType(b, it) + } + fun toCustomSection(b: ByteReader, afterSectionId: Int) = Node.CustomSection( afterSectionId = afterSectionId, name = b.readString(), - payload = b.readBytes("payload_data", b.size - b.index) + payload = b.readBytes("payload_data") + ) + + fun toData(b: ByteReader) = Node.Data( + index = b.readVarUInt32AsInt("index"), + offset = toInitExpr(b), + data = b.readVarUInt32AsInt("size").let { b.readBytes("data", it) } + ) + + fun toElem(b: ByteReader) = Node.Elem( + index = b.readVarUInt32AsInt("index"), + offset = toInitExpr(b), + funcIndices = b.readList { it.readVarUInt32AsInt("elems") } ) fun toElemType(b: ByteReader) = b.readVarInt7("elem_type").toInt().let { @@ -18,6 +37,26 @@ open class BinaryToAst(val version: Long = 0xd) { } } + fun toExport(b: ByteReader) = Node.Export( + field = b.readString(), + kind = b.readByte("external_kind").toInt().let { + when (it) { + 0 -> Node.ExternalKind.FUNCTION + 1 -> Node.ExternalKind.TABLE + 2 -> Node.ExternalKind.MEMORY + 3 -> Node.ExternalKind.GLOBAL + else -> error("Unrecognized export kind: $it") + } + }, + index = b.readVarUInt32AsInt("index") + ) + + fun toFunc(b: ByteReader, type: Node.Type.Func) = Node.Func( + type = type, + locals = b.readList(this::toValueType), + instructions = toInstrs(b).let { it.dropLast(1).also { require(it == listOf(Node.Instr.End)) } } + ) + fun toFuncType(b: ByteReader): Node.Type.Func { require(b.readVarInt7("form").toInt() == -0x20) return Node.Type.Func( @@ -26,6 +65,8 @@ open class BinaryToAst(val version: Long = 0xd) { ) } + fun toGlobal(b: ByteReader) = Node.Global(toGlobalType(b), toInitExpr(b)) + fun toGlobalType(b: ByteReader) = Node.Type.Global( contentType = toValueType(b), mutable = b.readVarUInt1("mutability") @@ -36,7 +77,7 @@ open class BinaryToAst(val version: Long = 0xd) { field = b.readString(), kind = b.readByte("external_kind").toInt().let { when (it) { - 0 -> Node.Import.Kind.Func(b.readVarUInt32("type").toIntExact()) + 0 -> Node.Import.Kind.Func(b.readVarUInt32AsInt("type")) 1 -> Node.Import.Kind.Table(toTableType(b)) 2 -> Node.Import.Kind.Memory(toMemoryType(b)) 3 -> Node.Import.Kind.Global(toGlobalType(b)) @@ -45,6 +86,57 @@ open class BinaryToAst(val version: Long = 0xd) { } ) + fun toInitExpr(b: ByteReader) = listOf(toInstr(b)).also { require(toInstr(b) == Node.Instr.End) } + + fun toInstrs(b: ByteReader) = mutableListOf().also { while (!b.isEof) it += toInstr(b) }.toList() + + fun toInstr(b: ByteReader) = Node.InstrOp.op(b.readByte("opcode").toUnsignedShort()).let { op -> + when (op) { + is Node.InstrOp.ControlFlowOp.NoArg -> + op.create + is Node.InstrOp.ControlFlowOp.TypeArg -> + op.create(toBlockType(b)) + is Node.InstrOp.ControlFlowOp.DepthArg -> + op.create(b.readVarUInt32AsInt("relative_depth")) + is Node.InstrOp.ControlFlowOp.TableArg -> op.create( + b.readList { it.readVarUInt32AsInt("target_table") }, + b.readVarUInt32AsInt("default_target") + ) + is Node.InstrOp.CallOp.IndexArg -> + op.create(b.readVarUInt32AsInt("function_index")) + is Node.InstrOp.CallOp.IndexReservedArg -> op.create( + b.readVarUInt32AsInt("type_index"), + b.readVarUInt1("reserved") + ) + is Node.InstrOp.ParamOp.NoArg -> + op.create + is Node.InstrOp.VarOp.IndexArg -> + op.create(b.readVarUInt32AsInt("index")) + is Node.InstrOp.MemOp.AlignOffsetArg -> op.create( + b.readVarUInt32AsInt("flags"), + b.readVarUInt32("offset") + ) + is Node.InstrOp.MemOp.ReservedArg -> + op.create(b.readVarUInt1("reserved")) + is Node.InstrOp.ConstOp.IntArg -> + op.create(b.readVarInt32("value")) + is Node.InstrOp.ConstOp.LongArg -> + op.create(b.readVarInt64("value")) + is Node.InstrOp.ConstOp.FloatArg -> + op.create(Float.fromIntBits(b.readUInt32("value").toIntExact())) + is Node.InstrOp.ConstOp.DoubleArg -> + op.create(Double.fromLongBits(b.readUInt64("value").longValueExact())) + is Node.InstrOp.CompareOp.NoArg -> + op.create + is Node.InstrOp.NumOp.NoArg -> + op.create + is Node.InstrOp.ConvertOp.NoArg -> + op.create + is Node.InstrOp.ReinterpretOp.NoArg -> + op.create + } + } + fun toMemoryType(b: ByteReader) = Node.Type.Memory(toResizableLimits(b)) fun toModule(b: ByteReader): Node.Module { @@ -57,17 +149,25 @@ open class BinaryToAst(val version: Long = 0xd) { while (!b.isEof) { val sectionId = b.readVarUInt7("id").toInt() if (sectionId != 0) require(sectionId > maxSectionId).also { maxSectionId = sectionId } - sections += sectionId to b.slice("payload_data", b.readVarUInt32("payload_len").toIntExact()) + sections += sectionId to b.slice("payload_data", b.readVarUInt32AsInt("payload_len")) } // Now build the module fun readSectionList(sectionId: Int, fn: (ByteReader) -> T) = sections.find { it.first == sectionId }?.second?.readList(fn) ?: emptyList() val types = readSectionList(1, this::toFuncType) + val funcIndices = readSectionList(3) { it.readVarUInt32AsInt("types") } return Node.Module( types = types, imports = readSectionList(2, this::toImport), - tables = TODO("Keep going..."), + tables = readSectionList(4, this::toTableType), + memories = readSectionList(5, this::toMemoryType), + globals = readSectionList(6, this::toGlobal), + exports = readSectionList(7, this::toExport), + startFuncIndex = sections.find { it.first == 8 }?.second?.readVarUInt32AsInt("index"), + elems = readSectionList(9, this::toElem), + funcs = readSectionList(10) { it }.zip(funcIndices.map { types[it] }, this::toFunc), + data = readSectionList(11, this::toData), customSections = sections.foldIndexed(emptyList()) { index, customSections, (sectionId, b) -> if (sectionId != 0) customSections else { // If the last section was custom, use the last custom section's after-ID, @@ -83,14 +183,15 @@ open class BinaryToAst(val version: Long = 0xd) { fun toResizableLimits(b: ByteReader) = b.readVarUInt1("flags").let { Node.ResizableLimits( - initial = b.readVarUInt32("initial").toIntExact(), - maximum = if (it) b.readVarUInt32("maximum").toIntExact() else null + initial = b.readVarUInt32AsInt("initial"), + maximum = if (it) b.readVarUInt32AsInt("maximum") else null ) } fun toTableType(b: ByteReader) = Node.Type.Table(toElemType(b), toResizableLimits(b)) - fun toValueType(b: ByteReader) = when (b.readVarInt7("value_type").toInt()) { + fun toValueType(b: ByteReader) = toValueType(b, b.readVarInt7("value_type").toInt()) + fun toValueType(b: ByteReader, type: Int) = when (type) { -0x01 -> Node.Type.Value.I32 -0x02 -> Node.Type.Value.I64 -0x03 -> Node.Type.Value.F32 @@ -98,8 +199,9 @@ open class BinaryToAst(val version: Long = 0xd) { else -> error("Unknown value type") } - fun ByteReader.readString() = this.readVarUInt32("len").toIntExact().let { String(this.readBytes("str", it)) } + fun ByteReader.readString() = this.readVarUInt32AsInt("len").let { String(this.readBytes("str", it)) } fun ByteReader.readList(fn: (ByteReader) -> T) = (0 until this.readVarUInt32("count")).map { _ -> fn(this) } + fun ByteReader.readVarUInt32AsInt(field: String) = this.readVarUInt32(field).toIntExact() companion object : BinaryToAst() } \ No newline at end of file diff --git a/src/main/kotlin/asmble/io/ByteReader.kt b/src/main/kotlin/asmble/io/ByteReader.kt index 3da9c82..468f1ae 100644 --- a/src/main/kotlin/asmble/io/ByteReader.kt +++ b/src/main/kotlin/asmble/io/ByteReader.kt @@ -1,14 +1,17 @@ package asmble.io +import asmble.util.toIntExact +import asmble.util.toUnsignedBigInt +import asmble.util.toUnsignedLong import java.math.BigInteger +import java.nio.ByteBuffer +import java.nio.ByteOrder interface ByteReader { val isEof: Boolean - val index: Int - val size: Int fun readByte(field: String): Byte - fun readBytes(field: String, amount: Int): ByteArray + fun readBytes(field: String, amount: Int? = null): ByteArray fun readUInt32(field: String): Long fun readUInt64(field: String): BigInteger fun readVarInt7(field: String): Byte @@ -18,4 +21,73 @@ interface ByteReader { fun readVarUInt7(field: String): Short fun readVarUInt32(field: String): Long fun slice(field: String, amount: Int): ByteReader + + class Buffer(val buf: ByteBuffer) : ByteReader { + init { buf.order(ByteOrder.LITTLE_ENDIAN) } + + override val isEof get() = buf.position() == buf.limit() + + override fun readByte(field: String) = buf.get() + + override fun readBytes(field: String, amount: Int?) = + ByteArray(amount ?: buf.limit() - buf.position()).also { buf.get(it) } + + override fun readUInt32(field: String) = buf.getInt().toUnsignedLong() + + override fun readUInt64(field: String) = buf.getLong().toUnsignedBigInt() + + override fun readVarInt7(field: String) = readSignedLeb128().let { + require(it >= Byte.MIN_VALUE.toLong() && it <= Byte.MAX_VALUE.toLong()) + it.toByte() + } + + override fun readVarInt32(field: String) = readSignedLeb128().toIntExact() + + override fun readVarInt64(field: String) = readSignedLeb128() + + override fun readVarUInt1(field: String) = readUnsignedLeb128().let { + require(it == 1 || it == 0) + it == 1 + } + + override fun readVarUInt7(field: String) = readUnsignedLeb128().let { + require(it <= 255) + it.toShort() + } + + override fun readVarUInt32(field: String) = readUnsignedLeb128().toUnsignedLong() + + override fun slice(field: String, amount: Int) = ByteReader.Buffer(buf.slice().also { it.limit(amount) }) + + private fun readUnsignedLeb128(): Int { + // Taken from Android source, Apache licensed + var result = 0 + var cur = 0 + var count = 0 + do { + cur = buf.get().toInt() and 0xff + result = result or ((cur and 0x7f) shl (count * 7)) + count++ + } while (cur and 0x80 == 0x80 && count < 5) + if (cur and 0x80 == 0x80) throw NumberFormatException() + return result + } + + private fun readSignedLeb128(): Long { + // Taken from Android source, Apache licensed + var result = 0L + var cur = 0 + var count = 0 + var signBits = -1L + do { + cur = buf.get().toInt() and 0xff + result = result or ((cur and 0x7f).toLong() shl (count * 7)) + signBits = signBits shl 7 + count++ + } while (cur and 0x80 == 0x80 && count < 5) + if (cur and 0x80 == 0x80) throw NumberFormatException() + if ((signBits shr 1) and result != 0L) result = result or signBits + return result + } + } } \ No newline at end of file diff --git a/src/main/kotlin/asmble/util/NumExt.kt b/src/main/kotlin/asmble/util/NumExt.kt index 996d2fe..869d569 100644 --- a/src/main/kotlin/asmble/util/NumExt.kt +++ b/src/main/kotlin/asmble/util/NumExt.kt @@ -4,6 +4,8 @@ import java.math.BigInteger internal const val INT_MASK = 0xffffffffL +fun Byte.toUnsignedShort() = if (this >= 0) this.toShort() else (this.toInt() + 128).toShort() + fun BigInteger.unsignedToSignedLong(): Long { if (this.signum() < 0 || this.bitLength() > java.lang.Long.SIZE) throw NumberFormatException() return this.toLong() @@ -20,7 +22,7 @@ fun Float.Companion.fromIntBits(v: Int) = java.lang.Float.intBitsToFloat(v) fun Int.toUnsignedLong() = java.lang.Integer.toUnsignedLong(this) fun Long.toIntExact() = - if (this > Int.MAX_VALUE.toLong()) throw NumberFormatException() + if (this > Int.MAX_VALUE.toLong() || this < Int.MIN_VALUE.toLong()) throw NumberFormatException() else this.toInt() fun Long.toUnsignedBigInt() =