diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Ctx.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Ctx.scala index 27da6b0c20..1fd1904863 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Ctx.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Ctx.scala @@ -9,7 +9,7 @@ import hkmc2.utils.* import document.* import document.Document import semantics.* -import text.Param as WasmParam +import hkmc2.codegen.wasm.text.Param as WasmParam import Instructions.* import scala.collection.mutable.{ArrayBuffer as ArrayBuf, Map as MutMap} @@ -129,6 +129,9 @@ class TypeInfo( end TypeInfo object Ctx: + enum WasmIntrinsicType: + case TupleArray(mutable: Bool) + val binaryOps: Map[Str, (Expr, Expr) => Expr] = Map( "plus_impl" -> i32.add, "minus_impl" -> i32.sub, @@ -192,6 +195,7 @@ class Ctx( import Ctx.prettyString private val wasmIntrinsicFuncs: MutMap[Str, FuncIdx] = MutMap.empty + private val wasmIntrinsicTypes: MutMap[Ctx.WasmIntrinsicType, TypeIdx] = MutMap.empty /** Adds a type into this context. */ def addType(sym: Opt[BlockMemberSymbol], typeInfo: TypeInfo): TypeIdx = @@ -324,6 +328,13 @@ class Ctx( def getOrCreateWasmIntrinsic(name: Str, createIntrinsic: => FuncIdx): FuncIdx = wasmIntrinsicFuncs.getOrElseUpdate(name, createIntrinsic) + /** + * Returns the cached [[TypeIdx]] for the intrinsic type `key`, creating it with `createType` if + * it does not yet exist in this context. + */ + def getOrCreateWasmIntrinsicType(key: Ctx.WasmIntrinsicType, createType: => TypeIdx): TypeIdx = + wasmIntrinsicTypes.getOrElseUpdate(key, createType) + def toWat: Document = doc"(module #{ # ${(types.toSeq ++ funcs.toSeq).map(_.toWat).mkDocument(doc" # ")}) #} " diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Instructions.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Instructions.scala index 2889176599..e19be10393 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Instructions.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Instructions.scala @@ -243,6 +243,30 @@ object Instructions: stackargs = Seq(arrayRef), resultType = S(I32Type) ) + + /** Creates an `array.new_fixed` instruction. */ + def new_fixed(arrayType: TypeIdx, items: Seq[Expr]): FoldedInstr = FoldedInstr( + mnemonic = "array.new_fixed", + instrargs = Seq(arrayType.toWat, doc"${items.length}"), + stackargs = items, + resultType = S(RefType(arrayType, nullable = false)) + ) + + /** Creates an `array.get` instruction. */ + def get(arrayType: TypeIdx, arrayRef: Expr, index: Expr, elemType: Type): FoldedInstr = FoldedInstr( + mnemonic = "array.get", + instrargs = Seq(arrayType.toWat), + stackargs = Seq(arrayRef, index), + resultType = S(elemType) + ) + + /** Creates an `array.set` instruction. */ + def set(arrayType: TypeIdx, arrayRef: Expr, index: Expr, value: Expr): FoldedInstr = FoldedInstr( + mnemonic = "array.set", + instrargs = Seq(arrayType.toWat), + stackargs = Seq(arrayRef, index, value), + resultType = N + ) end array object ref: diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Wasm.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Wasm.scala index ebfd332b18..0ca1e0ac6c 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Wasm.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Wasm.scala @@ -135,8 +135,19 @@ case class StructType( def toWat: Document = doc"(struct${fieldSeq.map(_.toWat).mkDocument(doc" ").surroundUnlessEmpty(doc" ")})" +/** A type representing an array type. */ +case class ArrayType( + elemType: Type, + mutable: Bool, +) extends ToWat: + private def elemDoc: Document = + if mutable then doc"(mut ${elemType.toWat})" else elemType.toWat + + def toWat: Document = + doc"(array ${elemDoc})" + /** A composite type. */ -type CompType = StructType | FunctionType +type CompType = StructType | FunctionType | ArrayType type AbsHeapType = HeapType.Func.type diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala index c9c39b0a2f..f0e6fea4aa 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala @@ -37,6 +37,8 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: private val baseObjectSym: BlockMemberSymbol = BlockMemberSymbol("Object", Nil) private val tagFieldSym: TermSymbol = TermSymbol(syntax.MutVal, owner = N, Ident("$tag")) + private case class ActiveLabel(sym: Local, breakLabel: Str, continueLabel: Opt[Str]) + private var activeLabels: List[ActiveLabel] = Nil private def baseObjectTypeIdx(using Ctx): TypeIdx = ctx.getType_!(baseObjectSym) @@ -49,6 +51,79 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: private def baseObjectRefType(nullable: Bool)(using Ctx): RefType = RefType(baseObjectTypeIdx, nullable = nullable) + private def tupleArrayType(mut: Bool)(using Ctx): TypeIdx = + ctx.getOrCreateWasmIntrinsicType( + Ctx.WasmIntrinsicType.TupleArray(mutable = mut), + createType = + val suffix = if mut then "Mut" else "" + val sym = BlockMemberSymbol(s"TupleArray$suffix", Nil) + ctx.addType( + sym = S(sym), + TypeInfo( + sym, + ArrayType( + elemType = RefType.anyref, + mutable = mut + ) + ) + ) + ) + + private def tupleArrayGet( + tupleExpr: Expr, + idxBuilder: Expr => Expr + )(using Ctx, Raise, Scope): Expr = + val elemType = RefType.anyref + val mutArrayType = tupleArrayType(true) + val immArrayType = tupleArrayType(false) + val tupleIsMutable = ref.test(tupleExpr, RefType(mutArrayType, nullable = true)) + val mutableBranch = + val tupleRef = ref.cast(tupleExpr, RefType(mutArrayType, nullable = false)) + array.get(mutArrayType, tupleRef, idxBuilder(tupleRef), elemType) + val immutableBranch = + val tupleRef = ref.cast(tupleExpr, RefType(immArrayType, nullable = false)) + array.get(immArrayType, tupleRef, idxBuilder(tupleRef), elemType) + Instructions.`if`( + condition = tupleIsMutable, + ifTrue = mutableBranch, + ifFalse = S(immutableBranch), + resultTypes = Seq(Result(elemType.asValType_!)) + ) + + private def compileTupleIndex( + fld: Path, + loc: Opt[Loc], + errCtx: Str, + extra: => Str + )(using Ctx, Raise, Scope): Expr => Expr = + fld match + case Value.Lit(IntLit(value)) if value.isValidInt => + val idx = value.toInt + tupleRef => + if idx >= 0 then i32.const(idx) + else i32.add(array.len(tupleRef), i32.const(idx)) + case _ => + val rawIdx = result(fld) + val idxI32 = rawIdx.resultType match + case S(I32Type) => rawIdx + case S(RefType(HeapType.I31, _)) => i31.get(rawIdx, signed = true) + case S(RefType(HeapType.Any, _)) => + val casted = ref.cast(rawIdx, RefType.i31ref) + i31.get(casted, signed = true) + case ty => + return (_: Expr) => errExpr( + msg"$errCtx expects an integer index but found ${ty.fold("(none)")(_.toWat.mkString())}" -> loc + :: Nil, + extraInfo = S(extra) + ) + tupleRef => + Instructions.`if`( + condition = i32.lt_s(idxI32, i32.const(0)), + ifTrue = i32.add(idxI32, array.len(tupleRef)), + ifFalse = S(idxI32), + resultTypes = Seq(Result(I32Type)) + ) + /** * Raises a [[WarningReport]] with the given `warnMsgs` and `extraInfo`, and emits an * `unreachable` instruction. @@ -237,25 +312,47 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: case sel @ Select(qual, id) => val qualRes = result(qual) - val selSym = sel.symbol getOrElse: - lastWords(s"Symbol for Select(...) expression must be resolved") - val selTrmSym = selSym match - case termSym: TermSymbol => termSym - case sym => lastWords( - s"Expected resolved Select(...) expression to be a TermSymbol, but got $sym (${sym.getClass.getName})" + sel.symbol match + case S(selSym: TermSymbol) => + val selOwner = selSym.owner getOrElse: + lastWords(s"Expected resolved Select(...) expression `$selSym` to have an owner") + val selCls = selOwner.asBlkMember getOrElse: + lastWords( + s"Expected resolved class for Select(...) expression to be a BlockMemberSymbol, but got $selOwner (${selOwner.getClass.getName})" + ) + val fieldidx = fieldSelect(selCls, selSym) + struct.get( + fieldidx, + ref = ref.cast(qualRes, RefType(ctx.getType_!(selCls), nullable = false)), + ty = RefType.anyref ) - val selOwner = selTrmSym.owner getOrElse: - lastWords(s"Expected resolved Select(...) expression `$selTrmSym` to have an owner") - val selCls = selOwner.asBlkMember getOrElse: - lastWords( - s"Expected resolved class for Select(...) expression to be a BlockMemberSymbol, but got $selOwner (${selOwner.getClass.getName})" + case S(otherSym) => + lastWords( + s"Expected resolved Select(...) expression to be a TermSymbol, but got $otherSym (${otherSym.getClass.getName})" + ) + case N => + errExpr( + Ls( + msg"WatBuilder::result for field selection without a resolved symbol is not implemented (field `${id.name}`). Use `_.[_]` for index-based accesses." -> sel.toLoc + ), + extraInfo = S(sel) + ) + + case dyn @ DynSelect(qual, fld, arrayIdx) => + val qualRes = result(qual) + if arrayIdx then + val idxBuilder = compileTupleIndex( + fld = fld, + loc = fld.toLoc, + errCtx = "WatBuilder::result for array-style dynamic selections", + extra = dyn.toString + ) + tupleArrayGet(qualRes, idxBuilder) + else + errExpr( + Ls(msg"WatBuilder::result for dynamic field selections is not implemented yet" -> dyn.toLoc), + extraInfo = S(dyn) ) - val fieldidx = fieldSelect(selCls, selSym) - struct.get( - fieldidx, - ref = ref.cast(qualRes, RefType(ctx.getType_!(selCls), nullable = false)), - ty = RefType.anyref - ) case Instantiate(_, cls, as) => val ctorClsSymOpt = cls match @@ -287,6 +384,10 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: val objType = ctx.getFuncInfo_!(ctorFuncIdx).body.resultType_! call(funcidx = ctorFuncIdx, as.map(argument), Seq(Result(objType.asValType_!))) + case Tuple(mut, elems) => + val tupleValues = elems.map(argument) + array.new_fixed(tupleArrayType(mut), tupleValues) + case r => errExpr( Ls(msg"WatBackend::result for expression not implemented yet" -> r.toLoc), @@ -430,13 +531,73 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: lastWords( s"Expected `global.*` or `local.*` when compiling instruction for `$l`, but got ${lExpr.mnemonic}" ) - val rstBlk = returningTerm(rst) + val rstBlk = returningTerm(rst) Instructions.block( label = N, children = Seq(assignExpr, rstBlk), - resultTypes = rstBlk.resultTypes.map: ty => - Result(if ty is UnreachableType then RefType.anyref else ty.asValType_!) + resultTypes = rstBlk.resultTypes.map(r => Result(r.asValType_!)) + ) + + case assign @ AssignField(lhs, nme, rhs, rst) => + val lhsExpr = result(lhs) + val rhsExpr = result(rhs) + val assignInstr = assign.symbol match + case S(selSym: TermSymbol) => + val selOwner = selSym.owner getOrElse + lastWords(s"Expected resolved AssignField(...) expression `$selSym` to have an owner") + val selCls = selOwner.asBlkMember getOrElse + lastWords( + s"Expected resolved class for AssignField(...) expression to be a BlockMemberSymbol, but got $selOwner (${selOwner.getClass.getName})" + ) + val fieldidx = fieldSelect(selCls, selSym) + val objRef = ref.cast(lhsExpr, RefType(ctx.getType_!(selCls), nullable = false)) + struct.set(fieldidx, objRef, rhsExpr) + case S(otherSym) => + lastWords( + s"Expected resolved AssignField(...) expression to be a TermSymbol, but got $otherSym (${otherSym.getClass.getName})" + ) + case N => + errExpr( + Ls( + msg"WatBuilder::returningTerm for AssignField(...) without a resolved symbol is not implemented (field `${nme.name}`). Use `_.[_]` for index-based accesses." -> nme.toLoc + ), + extraInfo = S(assign) + ) + + val rstBlk = returningTerm(rst) + Instructions.block( + label = N, + children = Seq(assignInstr, rstBlk), + resultTypes = rstBlk.resultTypes.map(r => Result(r.asValType_!)) + ) + + case assign @ AssignDynField(lhs, fld, arrayIdx, rhs, rst) => + val lhsExpr = result(lhs) + val rhsExpr = result(rhs) + val assignInstr = + if arrayIdx then + val tupleArrayType = this.tupleArrayType(mut = true) + val tupleRef = ref.cast(lhsExpr, RefType(tupleArrayType, nullable = false)) + val idxBuilder = compileTupleIndex( + fld = fld, + loc = fld.toLoc, + errCtx = "WatBuilder::returningTerm for AssignDynField(...)", + extra = assign.toString + ) + val idxExpr = idxBuilder(tupleRef) + array.set(tupleArrayType, tupleRef, idxExpr, rhsExpr) + else + errExpr( + Ls(msg"WatBuilder::returningTerm for AssignDynField(...) where `arrayIdx = false` is not implemented yet" -> lhs.toLoc), + extraInfo = S(assign) + ) + + val rstBlk = returningTerm(rst) + Instructions.block( + label = N, + children = Seq(assignInstr, rstBlk), + resultTypes = rstBlk.resultTypes.map(r => Result(r.asValType_!)) ) case Define(defn, rst) => @@ -787,6 +948,32 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: ifFalse = N, resultTypes = Seq.empty )) + case Case.Tup(len, inf) => + val arrayRefType = RefType(HeapType.Array, nullable = true) + val isArrayTest = ref.test(getScrutExpr, arrayRefType) + + // Length check + val scrutArray = ref.cast(getScrutExpr, arrayRefType) + val arrayLength = array.len(scrutArray) + val lengthTest = if inf then + i32.ge_u(arrayLength, i32.const(len)) + else + i32.eq(arrayLength, i32.const(len)) + + val testExpr = i32.and(isArrayTest, lengthTest) + val bodyExpr = returningTerm(body) + val armLabelSym = TempSymbol(N, "arm") + val armLabel = scope.allocateName(armLabelSym) + S(Instructions.`if`( + condition = testExpr, + ifTrue = Instructions.block( + label = S(armLabel), + children = Seq(bodyExpr, br(matchLabel)), + resultTypes = Seq.empty + ), + ifFalse = N, + resultTypes = Seq.empty + )) case _ => break(errExpr( Ls( diff --git a/hkmc2/shared/src/test/mlscript/wasm/Tuples.mls b/hkmc2/shared/src/test/mlscript/wasm/Tuples.mls new file mode 100644 index 0000000000..ab1c865e45 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/wasm/Tuples.mls @@ -0,0 +1,87 @@ + +:global +:wasm + +[] +//│ Wasm result: +//│ = {} + +[0, 1, 2] +//│ Wasm result: +//│ = {} + +fun makePair(x, y) = [x, y] +makePair(1, 2).[1] +//│ Wasm result: +//│ = 2 + +class Foo(val x, val y) +let b = [Foo(0, 1), true, 2] +if b.[0] is Foo(x, y) then x else 3 +//│ Wasm result: +//│ = 0 + +let nums = [1, 2, 3] +nums.[-1] +//│ Wasm result: +//│ = 3 + +let nums = mut [1, 2, 3] +set nums.[0] = 0 +nums.[0] +//│ Wasm result: +//│ = 0 + +:re +let nums = [1, 2, 3] +set nums.[0] = 0 +nums.[0] +//│ Wasm result: +//│ ═══[RUNTIME ERROR] RuntimeError: illegal cast + +:re +let nums = [1, 2, 3] +nums.[3] +//│ Wasm result: +//│ ═══[RUNTIME ERROR] RuntimeError: array element access out of bounds + + +:wat +let c = [10, 20, 30, 40] +c.[3] +//│ Wat: +//│ (module +//│ (type $Object (sub (struct (field $$tag (mut i32))))) +//│ (type $TupleArray (array (ref null any))) +//│ (type $TupleArrayMut (array (mut (ref null any)))) +//│ (type (func (result (ref null any)))) +//│ (func $entry (type 3) (result (ref null any)) +//│ (local $c (ref null any)) +//│ (block (result (ref null any)) +//│ (local.set $c +//│ (array.new_fixed $TupleArray 4 +//│ (ref.i31 +//│ (i32.const 10)) +//│ (ref.i31 +//│ (i32.const 20)) +//│ (ref.i31 +//│ (i32.const 30)) +//│ (ref.i31 +//│ (i32.const 40)))) +//│ (if (result (ref null any)) +//│ (ref.test (ref null $TupleArrayMut) +//│ (local.get $c)) +//│ (then +//│ (array.get $TupleArrayMut +//│ (ref.cast (ref $TupleArrayMut) +//│ (local.get $c)) +//│ (i32.const 3))) +//│ (else +//│ (array.get $TupleArray +//│ (ref.cast (ref $TupleArray) +//│ (local.get $c)) +//│ (i32.const 3)))))) +//│ (export "entry" (func $entry)) +//│ (elem declare func $entry)) +//│ Wasm result: +//│ = 40