diff --git a/src/compiler.ts b/src/compiler.ts index a1654463..94d17b02 100644 --- a/src/compiler.ts +++ b/src/compiler.ts @@ -1954,7 +1954,7 @@ export class Compiler extends DiagnosticEmitter { condExpr = module.createI32(1); alwaysTrue = true; } - innerFlow.inheritNonnullIf(condExpr); + innerFlow.inheritNonnullIfTrue(condExpr); var incrExpr = statement.incrementor ? this.compileExpression(statement.incrementor, Type.void, ConversionKind.IMPLICIT, WrapMode.NONE) : 0; @@ -2041,7 +2041,7 @@ export class Compiler extends DiagnosticEmitter { // Each arm initiates a branch var ifTrueFlow = outerFlow.fork(); this.currentFlow = ifTrueFlow; - ifTrueFlow.inheritNonnullIf(condExpr); + ifTrueFlow.inheritNonnullIfTrue(condExpr); var ifTrueExpr = this.compileStatement(ifTrue); ifTrueFlow.freeScopedLocals(); this.currentFlow = outerFlow; @@ -2050,7 +2050,7 @@ export class Compiler extends DiagnosticEmitter { if (ifFalse) { let ifFalseFlow = outerFlow.fork(); this.currentFlow = ifFalseFlow; - ifFalseFlow.inheritNonnullIfNot(condExpr); + ifFalseFlow.inheritNonnullIfFalse(condExpr); ifFalseExpr = this.compileStatement(ifFalse); ifFalseFlow.freeScopedLocals(); this.currentFlow = outerFlow; @@ -2058,7 +2058,7 @@ export class Compiler extends DiagnosticEmitter { } else { outerFlow.inheritConditional(ifTrueFlow); if (ifTrueFlow.isAny(FlowFlags.ANY_TERMINATING)) { - outerFlow.inheritNonnullIfNot(condExpr); + outerFlow.inheritNonnullIfFalse(condExpr); } } return module.createIf(condExpr, ifTrueExpr, ifFalseExpr); @@ -2436,7 +2436,7 @@ export class Compiler extends DiagnosticEmitter { var continueLabel = "continue|" + label; innerFlow.continueLabel = continueLabel; - innerFlow.inheritNonnullIf(condExpr); + innerFlow.inheritNonnullIfTrue(condExpr); var body = this.compileStatement(statement.statement); var alwaysTrue = false; // TODO var terminated = innerFlow.isAny(FlowFlags.ANY_TERMINATING); @@ -4759,7 +4759,7 @@ export class Compiler extends DiagnosticEmitter { let previousFlow = this.currentFlow; let rightFlow = previousFlow.fork(); this.currentFlow = rightFlow; - rightFlow.inheritNonnullIf(leftExpr); + rightFlow.inheritNonnullIfTrue(leftExpr); rightExpr = this.compileExpression(right, leftType, ConversionKind.IMPLICIT, WrapMode.NONE); rightType = leftType; this.currentFlow = previousFlow; @@ -4809,7 +4809,7 @@ export class Compiler extends DiagnosticEmitter { let previousFlow = this.currentFlow; let rightFlow = previousFlow.fork(); this.currentFlow = rightFlow; - rightFlow.inheritNonnullIfNot(leftExpr); + rightFlow.inheritNonnullIfFalse(leftExpr); rightExpr = this.compileExpression(right, leftType, ConversionKind.IMPLICIT, WrapMode.NONE); rightType = leftType; this.currentFlow = previousFlow; @@ -5030,7 +5030,6 @@ export class Compiler extends DiagnosticEmitter { var module = this.module; var flow = this.currentFlow; var target = this.resolver.resolveExpression(expression, flow); // reports - var possiblyNull = this.currentType.is(TypeFlags.NULLABLE) && !flow.isNonnull(valueExpr); if (!target) return module.createUnreachable(); switch (target.kind) { @@ -5043,7 +5042,7 @@ export class Compiler extends DiagnosticEmitter { this.currentType = tee ? (target).type : Type.void; return module.createUnreachable(); } - return this.makeLocalAssignment(target, valueExpr, tee, possiblyNull); + return this.makeLocalAssignment(target, valueExpr, tee, !flow.isNonnull(this.currentType, valueExpr)); } case ElementKind.GLOBAL: { if (!this.compileGlobal(target)) return module.createUnreachable(); @@ -5242,7 +5241,10 @@ export class Compiler extends DiagnosticEmitter { if (!flow.canOverflow(valueExpr, type)) flow.setLocalFlag(localIndex, LocalFlags.WRAPPED); else flow.unsetLocalFlag(localIndex, LocalFlags.WRAPPED); } - if (possiblyNull) flow.unsetLocalFlag(localIndex, LocalFlags.NONNULL); + if (type.is(TypeFlags.NULLABLE)) { + if (possiblyNull) flow.unsetLocalFlag(localIndex, LocalFlags.NONNULL); + else flow.setLocalFlag(localIndex, LocalFlags.NONNULL); + } if (tee) { this.currentType = type; return this.module.createTeeLocal(localIndex, valueExpr); diff --git a/src/flow.ts b/src/flow.ts index 8c090752..4a7227d1 100644 --- a/src/flow.ts +++ b/src/flow.ts @@ -582,8 +582,12 @@ export class Flow { this.localFlags = combinedFlags; } - /** Checks if an expression is known to be non-null. */ - isNonnull(expr: ExpressionRef): bool { + /** Checks if an expression of the specified type is known to be non-null, even if the type might be nullable. */ + isNonnull(type: Type, expr: ExpressionRef): bool { + if (!type.is(TypeFlags.NULLABLE)) return true; + // below, only teeLocal/getLocal are relevant because these are the only expressions that + // depend on a dynamic nullable state (flag = LocalFlags.NONNULL), while everything else + // has already been handled by the nullable type check above. switch (getExpressionId(expr)) { case ExpressionId.SetLocal: { if (!isTeeLocal(expr)) break; @@ -598,8 +602,9 @@ export class Flow { return false; } - /** Sets local states where this branch is only taken when `expr` is true-ish. */ - inheritNonnullIf(expr: ExpressionRef): void { + /** Updates local states to reflect that this branch is only taken when `expr` is true-ish. */ + inheritNonnullIfTrue(expr: ExpressionRef): void { + // A: `expr` is true-ish -> Q: how did that happen? switch (getExpressionId(expr)) { case ExpressionId.SetLocal: { if (!isTeeLocal(expr)) break; @@ -607,7 +612,7 @@ export class Flow { this.setLocalFlag(local.index, LocalFlags.NONNULL); break; } - case ExpressionId.GetLocal: { // local must be true-ish/non-null + case ExpressionId.GetLocal: { let local = this.parentFunction.localsByIndex[getGetLocalIndex(expr)]; this.setLocalFlag(local.index, LocalFlags.NONNULL); break; @@ -615,22 +620,24 @@ export class Flow { case ExpressionId.If: { let ifFalse = getIfFalse(expr); if (!ifFalse) break; - if (getExpressionId(ifFalse) == ExpressionId.Const && getExpressionType(ifFalse) == NativeType.I32 && getConstValueI32(ifFalse) == 0) { + if (getExpressionId(ifFalse) == ExpressionId.Const) { // Logical AND: (if (condition ifTrue 0)) - // the only way this can become true is if condition and ifTrue are true - this.inheritNonnullIf(getIfCondition(expr)); - this.inheritNonnullIf(getIfTrue(expr)); + // the only way this had become true is if condition and ifTrue are true + if ( + (getExpressionType(ifFalse) == NativeType.I32 && getConstValueI32(ifFalse) == 0) || + (getExpressionType(ifFalse) == NativeType.I64 && getConstValueI64Low(ifFalse) == 0 && getConstValueI64High(ifFalse) == 0) + ) { + this.inheritNonnullIfTrue(getIfCondition(expr)); + this.inheritNonnullIfTrue(getIfTrue(expr)); + } } break; } case ExpressionId.Unary: { switch (getUnaryOp(expr)) { - case UnaryOp.EqzI32: { - this.inheritNonnullIfNot(getUnaryValue(expr)); // !expr - break; - } + case UnaryOp.EqzI32: case UnaryOp.EqzI64: { - this.inheritNonnullIfNot(getUnaryValue(expr)); // !expr + this.inheritNonnullIfFalse(getUnaryValue(expr)); // !value -> value must have been false break; } } @@ -642,9 +649,9 @@ export class Flow { let left = getBinaryLeft(expr); let right = getBinaryRight(expr); if (getExpressionId(left) == ExpressionId.Const && getConstValueI32(left) != 0) { - this.inheritNonnullIf(right); // TRUE == right + this.inheritNonnullIfTrue(right); // TRUE == right -> right must have been true } else if (getExpressionId(right) == ExpressionId.Const && getConstValueI32(right) != 0) { - this.inheritNonnullIf(left); // left == TRUE + this.inheritNonnullIfTrue(left); // left == TRUE -> left must have been true } break; } @@ -652,9 +659,9 @@ export class Flow { let left = getBinaryLeft(expr); let right = getBinaryRight(expr); if (getExpressionId(left) == ExpressionId.Const && (getConstValueI64Low(left) != 0 || getConstValueI64High(left) != 0)) { - this.inheritNonnullIf(right); // TRUE == right + this.inheritNonnullIfTrue(right); // TRUE == right -> right must have been true } else if (getExpressionId(right) == ExpressionId.Const && (getConstValueI64Low(right) != 0 && getConstValueI64High(right) != 0)) { - this.inheritNonnullIf(left); // left == TRUE + this.inheritNonnullIfTrue(left); // left == TRUE -> left must have been true } break; } @@ -662,9 +669,9 @@ export class Flow { let left = getBinaryLeft(expr); let right = getBinaryRight(expr); if (getExpressionId(left) == ExpressionId.Const && getConstValueI32(left) == 0) { - this.inheritNonnullIf(right); // FALSE != right + this.inheritNonnullIfTrue(right); // FALSE != right -> right must have been true } else if (getExpressionId(right) == ExpressionId.Const && getConstValueI32(right) == 0) { - this.inheritNonnullIf(left); // left != FALSE + this.inheritNonnullIfTrue(left); // left != FALSE -> left must have been true } break; } @@ -672,9 +679,9 @@ export class Flow { let left = getBinaryLeft(expr); let right = getBinaryRight(expr); if (getExpressionId(left) == ExpressionId.Const && getConstValueI64Low(left) == 0 && getConstValueI64High(left) == 0) { - this.inheritNonnullIf(right); // FALSE != right + this.inheritNonnullIfTrue(right); // FALSE != right -> right must have been true } else if (getExpressionId(right) == ExpressionId.Const && getConstValueI64Low(right) == 0 && getConstValueI64High(right) == 0) { - this.inheritNonnullIf(left); // left != FALSE + this.inheritNonnullIfTrue(left); // left != FALSE -> left must have been true } break; } @@ -684,17 +691,15 @@ export class Flow { } } - /** Sets local states where this branch is only taken when `expr` is false-ish. */ - inheritNonnullIfNot(expr: ExpressionRef): void { + /** Updates local states to reflect that this branch is only taken when `expr` is false-ish. */ + inheritNonnullIfFalse(expr: ExpressionRef): void { + // A: `expr` is false-ish -> Q: how did that happen? switch (getExpressionId(expr)) { case ExpressionId.Unary: { switch (getUnaryOp(expr)) { - case UnaryOp.EqzI32: { - this.inheritNonnullIf(getUnaryValue(expr)); // !expr - break; - } + case UnaryOp.EqzI32: case UnaryOp.EqzI64: { - this.inheritNonnullIf(getUnaryValue(expr)); // !expr + this.inheritNonnullIfTrue(getUnaryValue(expr)); // !value -> value must have been true break; } } @@ -702,25 +707,32 @@ export class Flow { } case ExpressionId.If: { let ifTrue = getIfTrue(expr); - if (getExpressionId(ifTrue) == ExpressionId.Const && getExpressionType(ifTrue) == NativeType.I32 && getConstValueI32(ifTrue) != 0) { + if (getExpressionId(ifTrue) == ExpressionId.Const) { let ifFalse = getIfFalse(expr); if (!ifFalse) break; // Logical OR: (if (condition 1 ifFalse)) - // the only way this can become false is if condition and ifFalse are false - this.inheritNonnullIfNot(getIfCondition(expr)); - this.inheritNonnullIfNot(getIfFalse(expr)); + // the only way this had become false is if condition and ifFalse are false + if ( + (getExpressionType(ifTrue) == NativeType.I32 && getConstValueI32(ifTrue) != 0) || + (getExpressionType(ifTrue) == NativeType.I64 && (getConstValueI64Low(ifTrue) != 0 || getConstValueI64High(ifTrue) != 0)) + ) { + this.inheritNonnullIfFalse(getIfCondition(expr)); + this.inheritNonnullIfFalse(getIfFalse(expr)); + } + } break; } case ExpressionId.Binary: { switch (getBinaryOp(expr)) { + // remember: we want to know how the _entire_ expression became FALSE (!) case BinaryOp.EqI32: { let left = getBinaryLeft(expr); let right = getBinaryRight(expr); if (getExpressionId(left) == ExpressionId.Const && getConstValueI32(left) == 0) { - this.inheritNonnullIf(right); // FALSE == right + this.inheritNonnullIfTrue(right); // FALSE == right -> right must have been true } else if (getExpressionId(right) == ExpressionId.Const && getConstValueI32(right) == 0) { - this.inheritNonnullIf(left); // left == FALSE + this.inheritNonnullIfTrue(left); // left == FALSE -> left must have been true } break; } @@ -728,9 +740,9 @@ export class Flow { let left = getBinaryLeft(expr); let right = getBinaryRight(expr); if (getExpressionId(left) == ExpressionId.Const && getConstValueI64Low(left) == 0 && getConstValueI64High(left) == 0) { - this.inheritNonnullIf(right); // FALSE == right + this.inheritNonnullIfTrue(right); // FALSE == right -> right must have been true } else if (getExpressionId(right) == ExpressionId.Const && getConstValueI64Low(right) == 0 && getConstValueI64High(right) == 0) { - this.inheritNonnullIf(left); // left == FALSE + this.inheritNonnullIfTrue(left); // left == FALSE -> left must have been true } break; } @@ -738,9 +750,9 @@ export class Flow { let left = getBinaryLeft(expr); let right = getBinaryRight(expr); if (getExpressionId(left) == ExpressionId.Const && getConstValueI32(left) != 0) { - this.inheritNonnullIf(right); // TRUE != right + this.inheritNonnullIfTrue(right); // TRUE != right -> right must have been true } else if (getExpressionId(right) == ExpressionId.Const && getConstValueI32(right) != 0) { - this.inheritNonnullIf(left); // left != TRUE + this.inheritNonnullIfTrue(left); // left != TRUE -> left must have been true } break; } @@ -748,9 +760,9 @@ export class Flow { let left = getBinaryLeft(expr); let right = getBinaryRight(expr); if (getExpressionId(left) == ExpressionId.Const && (getConstValueI64Low(left) != 0 || getConstValueI64High(left) != 0)) { - this.inheritNonnullIf(right); // TRUE != right + this.inheritNonnullIfTrue(right); // TRUE != right -> right must have been true for this to become false } else if (getExpressionId(right) == ExpressionId.Const && (getConstValueI64Low(right) != 0 || getConstValueI64High(right) != 0)) { - this.inheritNonnullIf(left); // left != TRUE + this.inheritNonnullIfTrue(left); // left != TRUE -> left must have been true for this to become false } break; } diff --git a/tests/compiler/possibly-null.optimized.wat b/tests/compiler/possibly-null.optimized.wat index c83e8f1a..fe8403d3 100644 --- a/tests/compiler/possibly-null.optimized.wat +++ b/tests/compiler/possibly-null.optimized.wat @@ -20,6 +20,7 @@ (export "testLogicalOr" (func $possibly-null/testTrue)) (export "testLogicalAndMulti" (func $possibly-null/testLogicalAndMulti)) (export "testLogicalOrMulti" (func $possibly-null/testLogicalAndMulti)) + (export "testAssign" (func $possibly-null/testLogicalAndMulti)) (func $possibly-null/testTrue (; 0 ;) (type $FUNCSIG$vi) (param $0 i32) nop ) diff --git a/tests/compiler/possibly-null.ts b/tests/compiler/possibly-null.ts index 01af0a4a..101c103f 100644 --- a/tests/compiler/possibly-null.ts +++ b/tests/compiler/possibly-null.ts @@ -99,14 +99,23 @@ export function testLogicalAndMulti(a: Ref | null, b: Ref | null): void { if (a && b) { if (isNullable(a)) ERROR("should be non-nullable"); if (isNullable(b)) ERROR("should be non-nullable"); + } else { + if (!isNullable(a)) ERROR("should be nullable"); + if (!isNullable(b)) ERROR("should be nullable"); } } export function testLogicalOrMulti(a: Ref | null, b: Ref | null): void { if (!a || !b) { - // something + if (!isNullable(a)) ERROR("should be nullable"); + if (!isNullable(b)) ERROR("should be nullable"); } else { if (isNullable(a)) ERROR("should be non-nullable"); if (isNullable(b)) ERROR("should be non-nullable"); } } + +export function testAssign(a: Ref | null, b: Ref): void { + a = b; + if (isNullable(a)) ERROR("should be non-nullable"); +} diff --git a/tests/compiler/possibly-null.untouched.wat b/tests/compiler/possibly-null.untouched.wat index 5ad30504..a4be7dfb 100644 --- a/tests/compiler/possibly-null.untouched.wat +++ b/tests/compiler/possibly-null.untouched.wat @@ -23,6 +23,7 @@ (export "testLogicalOr" (func $possibly-null/testLogicalOr)) (export "testLogicalAndMulti" (func $possibly-null/testLogicalAndMulti)) (export "testLogicalOrMulti" (func $possibly-null/testLogicalOrMulti)) + (export "testAssign" (func $possibly-null/testAssign)) (func $possibly-null/testTrue (; 0 ;) (type $FUNCSIG$vi) (param $0 i32) local.get $0 if @@ -168,6 +169,8 @@ end if nop + else + nop end ) (func $possibly-null/testLogicalOrMulti (; 16 ;) (type $FUNCSIG$vii) (param $0 i32) (param $1 i32) @@ -185,6 +188,10 @@ nop end ) - (func $null (; 17 ;) (type $FUNCSIG$v) + (func $possibly-null/testAssign (; 17 ;) (type $FUNCSIG$vii) (param $0 i32) (param $1 i32) + local.get $1 + local.set $0 + ) + (func $null (; 18 ;) (type $FUNCSIG$v) ) )