Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Ctx.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import hkmc2.utils.*
import document.*
import document.Document
import semantics.*
import text.Param as WasmParam
import hkmc2.codegen.wasm.text.Param as WasmParam
import Instructions.*

import scala.collection.mutable.{ArrayBuffer as ArrayBuf, Map as MutMap}
Expand Down Expand Up @@ -129,6 +129,9 @@ class TypeInfo(
end TypeInfo

object Ctx:
enum WasmIntrinsicType:
case TupleArray(mutable: Bool)

val binaryOps: Map[Str, (Expr, Expr) => Expr] = Map(
"plus_impl" -> i32.add,
"minus_impl" -> i32.sub,
Expand Down Expand Up @@ -192,6 +195,7 @@ class Ctx(
import Ctx.prettyString

private val wasmIntrinsicFuncs: MutMap[Str, FuncIdx] = MutMap.empty
private val wasmIntrinsicTypes: MutMap[Ctx.WasmIntrinsicType, TypeIdx] = MutMap.empty

/** Adds a type into this context. */
def addType(sym: Opt[BlockMemberSymbol], typeInfo: TypeInfo): TypeIdx =
Expand Down Expand Up @@ -324,6 +328,13 @@ class Ctx(
def getOrCreateWasmIntrinsic(name: Str, createIntrinsic: => FuncIdx): FuncIdx =
wasmIntrinsicFuncs.getOrElseUpdate(name, createIntrinsic)

/**
* Returns the cached [[TypeIdx]] for the intrinsic type `key`, creating it with `createType` if
* it does not yet exist in this context.
*/
def getOrCreateWasmIntrinsicType(key: Ctx.WasmIntrinsicType, createType: => TypeIdx): TypeIdx =
wasmIntrinsicTypes.getOrElseUpdate(key, createType)

def toWat: Document =
doc"(module #{ # ${(types.toSeq ++ funcs.toSeq).map(_.toWat).mkDocument(doc" # ")}) #} "

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,30 @@ object Instructions:
stackargs = Seq(arrayRef),
resultType = S(I32Type)
)

/** Creates an `array.new_fixed` instruction. */
def new_fixed(arrayType: TypeIdx, items: Seq[Expr]): FoldedInstr = FoldedInstr(
mnemonic = "array.new_fixed",
instrargs = Seq(arrayType.toWat, doc"${items.length}"),
stackargs = items,
resultType = S(RefType(arrayType, nullable = false))
)

/** Creates an `array.get` instruction. */
def get(arrayType: TypeIdx, arrayRef: Expr, index: Expr, elemType: Type): FoldedInstr = FoldedInstr(
mnemonic = "array.get",
instrargs = Seq(arrayType.toWat),
stackargs = Seq(arrayRef, index),
resultType = S(elemType)
)

/** Creates an `array.set` instruction. */
def set(arrayType: TypeIdx, arrayRef: Expr, index: Expr, value: Expr): FoldedInstr = FoldedInstr(
mnemonic = "array.set",
instrargs = Seq(arrayType.toWat),
stackargs = Seq(arrayRef, index, value),
resultType = N
)
end array

object ref:
Expand Down
13 changes: 12 additions & 1 deletion hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Wasm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,19 @@ case class StructType(
def toWat: Document =
doc"(struct${fieldSeq.map(_.toWat).mkDocument(doc" ").surroundUnlessEmpty(doc" ")})"

/** A type representing an array type. */
case class ArrayType(
elemType: Type,
mutable: Bool,
) extends ToWat:
private def elemDoc: Document =
if mutable then doc"(mut ${elemType.toWat})" else elemType.toWat

def toWat: Document =
doc"(array ${elemDoc})"

/** A composite type. */
type CompType = StructType | FunctionType
type CompType = StructType | FunctionType | ArrayType

type AbsHeapType =
HeapType.Func.type
Expand Down
227 changes: 207 additions & 20 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder:

private val baseObjectSym: BlockMemberSymbol = BlockMemberSymbol("Object", Nil)
private val tagFieldSym: TermSymbol = TermSymbol(syntax.MutVal, owner = N, Ident("$tag"))
private case class ActiveLabel(sym: Local, breakLabel: Str, continueLabel: Opt[Str])
private var activeLabels: List[ActiveLabel] = Nil

private def baseObjectTypeIdx(using Ctx): TypeIdx =
ctx.getType_!(baseObjectSym)
Expand All @@ -49,6 +51,79 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder:
private def baseObjectRefType(nullable: Bool)(using Ctx): RefType =
RefType(baseObjectTypeIdx, nullable = nullable)

private def tupleArrayType(mut: Bool)(using Ctx): TypeIdx =
ctx.getOrCreateWasmIntrinsicType(
Ctx.WasmIntrinsicType.TupleArray(mutable = mut),
createType =
val suffix = if mut then "Mut" else ""
val sym = BlockMemberSymbol(s"TupleArray$suffix", Nil)
ctx.addType(
sym = S(sym),
TypeInfo(
sym,
ArrayType(
elemType = RefType.anyref,
mutable = mut
)
)
)
)

private def tupleArrayGet(
tupleExpr: Expr,
idxBuilder: Expr => Expr
)(using Ctx, Raise, Scope): Expr =
val elemType = RefType.anyref
val mutArrayType = tupleArrayType(true)
val immArrayType = tupleArrayType(false)
val tupleIsMutable = ref.test(tupleExpr, RefType(mutArrayType, nullable = true))
val mutableBranch =
val tupleRef = ref.cast(tupleExpr, RefType(mutArrayType, nullable = false))
array.get(mutArrayType, tupleRef, idxBuilder(tupleRef), elemType)
val immutableBranch =
val tupleRef = ref.cast(tupleExpr, RefType(immArrayType, nullable = false))
array.get(immArrayType, tupleRef, idxBuilder(tupleRef), elemType)
Instructions.`if`(
condition = tupleIsMutable,
ifTrue = mutableBranch,
ifFalse = S(immutableBranch),
resultTypes = Seq(Result(elemType.asValType_!))
)

private def compileTupleIndex(
fld: Path,
loc: Opt[Loc],
errCtx: Str,
extra: => Str
)(using Ctx, Raise, Scope): Expr => Expr =
fld match
case Value.Lit(IntLit(value)) if value.isValidInt =>
val idx = value.toInt
tupleRef =>
if idx >= 0 then i32.const(idx)
else i32.add(array.len(tupleRef), i32.const(idx))
case _ =>
val rawIdx = result(fld)
val idxI32 = rawIdx.resultType match
case S(I32Type) => rawIdx
case S(RefType(HeapType.I31, _)) => i31.get(rawIdx, signed = true)
case S(RefType(HeapType.Any, _)) =>
val casted = ref.cast(rawIdx, RefType.i31ref)
i31.get(casted, signed = true)
case ty =>
return (_: Expr) => errExpr(
msg"$errCtx expects an integer index but found ${ty.fold("(none)")(_.toWat.mkString())}" -> loc
:: Nil,
extraInfo = S(extra)
)
tupleRef =>
Instructions.`if`(
condition = i32.lt_s(idxI32, i32.const(0)),
ifTrue = i32.add(idxI32, array.len(tupleRef)),
ifFalse = S(idxI32),
resultTypes = Seq(Result(I32Type))
)

/**
* Raises a [[WarningReport]] with the given `warnMsgs` and `extraInfo`, and emits an
* `unreachable` instruction.
Expand Down Expand Up @@ -237,25 +312,47 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder:

case sel @ Select(qual, id) =>
val qualRes = result(qual)
val selSym = sel.symbol getOrElse:
lastWords(s"Symbol for Select(...) expression must be resolved")
val selTrmSym = selSym match
case termSym: TermSymbol => termSym
case sym => lastWords(
s"Expected resolved Select(...) expression to be a TermSymbol, but got $sym (${sym.getClass.getName})"
sel.symbol match
case S(selSym: TermSymbol) =>
val selOwner = selSym.owner getOrElse:
lastWords(s"Expected resolved Select(...) expression `$selSym` to have an owner")
val selCls = selOwner.asBlkMember getOrElse:
lastWords(
s"Expected resolved class for Select(...) expression to be a BlockMemberSymbol, but got $selOwner (${selOwner.getClass.getName})"
)
val fieldidx = fieldSelect(selCls, selSym)
struct.get(
fieldidx,
ref = ref.cast(qualRes, RefType(ctx.getType_!(selCls), nullable = false)),
ty = RefType.anyref
)
val selOwner = selTrmSym.owner getOrElse:
lastWords(s"Expected resolved Select(...) expression `$selTrmSym` to have an owner")
val selCls = selOwner.asBlkMember getOrElse:
lastWords(
s"Expected resolved class for Select(...) expression to be a BlockMemberSymbol, but got $selOwner (${selOwner.getClass.getName})"
case S(otherSym) =>
lastWords(
s"Expected resolved Select(...) expression to be a TermSymbol, but got $otherSym (${otherSym.getClass.getName})"
)
case N =>
errExpr(
Ls(
msg"WatBuilder::result for field selection without a resolved symbol is not implemented (field `${id.name}`). Use `_.[_]` for index-based accesses." -> sel.toLoc
),
extraInfo = S(sel)
)

case dyn @ DynSelect(qual, fld, arrayIdx) =>
val qualRes = result(qual)
if arrayIdx then
val idxBuilder = compileTupleIndex(
fld = fld,
loc = fld.toLoc,
errCtx = "WatBuilder::result for array-style dynamic selections",
extra = dyn.toString
)
tupleArrayGet(qualRes, idxBuilder)
else
errExpr(
Ls(msg"WatBuilder::result for dynamic field selections is not implemented yet" -> dyn.toLoc),
extraInfo = S(dyn)
)
val fieldidx = fieldSelect(selCls, selSym)
struct.get(
fieldidx,
ref = ref.cast(qualRes, RefType(ctx.getType_!(selCls), nullable = false)),
ty = RefType.anyref
)

case Instantiate(_, cls, as) =>
val ctorClsSymOpt = cls match
Expand Down Expand Up @@ -287,6 +384,10 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder:
val objType = ctx.getFuncInfo_!(ctorFuncIdx).body.resultType_!
call(funcidx = ctorFuncIdx, as.map(argument), Seq(Result(objType.asValType_!)))

case Tuple(mut, elems) =>
val tupleValues = elems.map(argument)
array.new_fixed(tupleArrayType(mut), tupleValues)

case r =>
errExpr(
Ls(msg"WatBackend::result for expression not implemented yet" -> r.toLoc),
Expand Down Expand Up @@ -430,13 +531,73 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder:
lastWords(
s"Expected `global.*` or `local.*` when compiling instruction for `$l`, but got ${lExpr.mnemonic}"
)
val rstBlk = returningTerm(rst)

val rstBlk = returningTerm(rst)
Instructions.block(
label = N,
children = Seq(assignExpr, rstBlk),
resultTypes = rstBlk.resultTypes.map: ty =>
Result(if ty is UnreachableType then RefType.anyref else ty.asValType_!)
resultTypes = rstBlk.resultTypes.map(r => Result(r.asValType_!))
)

case assign @ AssignField(lhs, nme, rhs, rst) =>
val lhsExpr = result(lhs)
val rhsExpr = result(rhs)
val assignInstr = assign.symbol match
case S(selSym: TermSymbol) =>
val selOwner = selSym.owner getOrElse
lastWords(s"Expected resolved AssignField(...) expression `$selSym` to have an owner")
val selCls = selOwner.asBlkMember getOrElse
lastWords(
s"Expected resolved class for AssignField(...) expression to be a BlockMemberSymbol, but got $selOwner (${selOwner.getClass.getName})"
)
val fieldidx = fieldSelect(selCls, selSym)
val objRef = ref.cast(lhsExpr, RefType(ctx.getType_!(selCls), nullable = false))
struct.set(fieldidx, objRef, rhsExpr)
case S(otherSym) =>
lastWords(
s"Expected resolved AssignField(...) expression to be a TermSymbol, but got $otherSym (${otherSym.getClass.getName})"
)
case N =>
errExpr(
Ls(
msg"WatBuilder::returningTerm for AssignField(...) without a resolved symbol is not implemented (field `${nme.name}`). Use `_.[_]` for index-based accesses." -> nme.toLoc
),
extraInfo = S(assign)
)

val rstBlk = returningTerm(rst)
Instructions.block(
label = N,
children = Seq(assignInstr, rstBlk),
resultTypes = rstBlk.resultTypes.map(r => Result(r.asValType_!))
)

case assign @ AssignDynField(lhs, fld, arrayIdx, rhs, rst) =>
val lhsExpr = result(lhs)
val rhsExpr = result(rhs)
val assignInstr =
if arrayIdx then
val tupleArrayType = this.tupleArrayType(mut = true)
val tupleRef = ref.cast(lhsExpr, RefType(tupleArrayType, nullable = false))
val idxBuilder = compileTupleIndex(
fld = fld,
loc = fld.toLoc,
errCtx = "WatBuilder::returningTerm for AssignDynField(...)",
extra = assign.toString
)
val idxExpr = idxBuilder(tupleRef)
array.set(tupleArrayType, tupleRef, idxExpr, rhsExpr)
else
errExpr(
Ls(msg"WatBuilder::returningTerm for AssignDynField(...) where `arrayIdx = false` is not implemented yet" -> lhs.toLoc),
extraInfo = S(assign)
)

val rstBlk = returningTerm(rst)
Instructions.block(
label = N,
children = Seq(assignInstr, rstBlk),
resultTypes = rstBlk.resultTypes.map(r => Result(r.asValType_!))
)

case Define(defn, rst) =>
Expand Down Expand Up @@ -787,6 +948,32 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder:
ifFalse = N,
resultTypes = Seq.empty
))
case Case.Tup(len, inf) =>
val arrayRefType = RefType(HeapType.Array, nullable = true)
val isArrayTest = ref.test(getScrutExpr, arrayRefType)

// Length check
val scrutArray = ref.cast(getScrutExpr, arrayRefType)
val arrayLength = array.len(scrutArray)
val lengthTest = if inf then
i32.ge_u(arrayLength, i32.const(len))
else
i32.eq(arrayLength, i32.const(len))

val testExpr = i32.and(isArrayTest, lengthTest)
val bodyExpr = returningTerm(body)
val armLabelSym = TempSymbol(N, "arm")
val armLabel = scope.allocateName(armLabelSym)
S(Instructions.`if`(
condition = testExpr,
ifTrue = Instructions.block(
label = S(armLabel),
children = Seq(bodyExpr, br(matchLabel)),
resultTypes = Seq.empty
),
ifFalse = N,
resultTypes = Seq.empty
))
case _ =>
break(errExpr(
Ls(
Expand Down
Loading
Loading