Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions hkmc2/shared/src/main/scala/hkmc2/Config.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ case class Config(
target: CompilationTarget,
rewriteWhileLoops: Bool,
tailRecOpt: Bool,
qqEnabled: Bool,
):

def stackSafety: Opt[StackSafety] = effectHandlers.flatMap(_.stackSafety)
Expand Down Expand Up @@ -52,6 +53,7 @@ object Config:
rewriteWhileLoops = false,
stageCode = false,
tailRecOpt = true,
qqEnabled = false,
)

case class SanityChecks(light: Bool)
Expand Down
2 changes: 1 addition & 1 deletion hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ sealed abstract class Block extends Product:
case HandleBlock(lhs, res, par, args, cls, hdr, bod, rst) => rst.definedVars + res
case TryBlock(sub, fin, rst) => sub.definedVars ++ fin.definedVars ++ rst.definedVars
case Label(lbl, _, bod, rst) => bod.definedVars ++ rst.definedVars
case Scoped(syms, body) => body.definedVars
case Scoped(syms, body) => body.definedVars ++ syms

lazy val size: Int = this match
case _: Return | _: Throw | _: End | _: Break | _: Continue => 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ class BufferableTransform()(using Ctx, State, Raise):
def mkFieldReplacer(buf: Local, baseIdx: Local) =
def getOffset(off: Int)(k: Path => Block): Block =
val idxSymbol = new TempSymbol(N, "idx")
Assign(idxSymbol, Call(State.builtinOpsMap("+").asPath, baseIdx.asPath.asArg :: Value.Lit(Tree.IntLit(off)).asArg :: Nil)(true, false, false),
k(DynSelect(buf.asPath.selSN("buf"), idxSymbol.asPath, true)))
Scoped(Set.single(idxSymbol), Assign(idxSymbol, Call(State.builtinOpsMap("+").asPath, baseIdx.asPath.asArg :: Value.Lit(Tree.IntLit(off)).asArg :: Nil)(true, false, false),
k(DynSelect(buf.asPath.selSN("buf"), idxSymbol.asPath, true))))
def assignToOffset(off: Int, r: Result, rst: Block) =
val idxSymbol = new TempSymbol(N, "idx")
Assign(idxSymbol, Call(State.builtinOpsMap("+").asPath, baseIdx.asPath.asArg :: Value.Lit(Tree.IntLit(off)).asArg :: Nil)(true, false, false),
AssignDynField(buf.asPath.selSN("buf"), idxSymbol.asPath, true, r, applyBlock(rst)))
Scoped(Set.single(idxSymbol), Assign(idxSymbol, Call(State.builtinOpsMap("+").asPath, baseIdx.asPath.asArg :: Value.Lit(Tree.IntLit(off)).asArg :: Nil)(true, false, false),
AssignDynField(buf.asPath.selSN("buf"), idxSymbol.asPath, true, r, applyBlock(rst))))
new BlockTransformer(SymbolSubst()):
override def applyBlock(b: Block): Block = b match
case Assign(l, r, rst) =>
Expand Down
43 changes: 1 addition & 42 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,46 +14,6 @@ import semantics.Elaborator.ctx
import semantics.Elaborator.State
import hkmc2.Config.EffectHandlers


/** - For function bodies, fuse all shallowly-nested scopes into one top-level one,
* because handler lowering relies on knowing all local variables in the function.
* - Assert the absence of Label(loop = true) blocks,
* because loops should be rewritten to functions first,
* otherwise we cannot fuse scopes correctly.
*/
class PreHandlerLowering extends BlockTransformer(new SymbolSubst):
override def applyBlock(b: Block): Block = b match
case Label(_, loop, _, _) =>
assert(!loop)
super.applyBlock(b)
case _ => super.applyBlock(b)

private var scopedSymForCurrentFun: Option[collection.mutable.Set[Symbol]] = None
override def applyFunBodyLikeBlock(b: Block): Block =
val prevScopedSymForCurrentFun = scopedSymForCurrentFun
val resBlk = b match
case Scoped(syms, body) =>
scopedSymForCurrentFun = Some(collection.mutable.Set.from(syms))
val newBody = applySubBlock(body)
new Scoped(scopedSymForCurrentFun.get, newBody)
case _ =>
scopedSymForCurrentFun = Some(collection.mutable.Set.empty[Symbol])
val newBlk = applySubBlock(b)
Scoped(scopedSymForCurrentFun.get, newBlk)
scopedSymForCurrentFun = prevScopedSymForCurrentFun
resBlk

override def applyScopedBlock(b: Block): Block = b match
case Scoped(syms, body) =>
scopedSymForCurrentFun match
case None => super.applyScopedBlock(b)
case Some(scopedForCurrentFun) =>
scopedForCurrentFun.addAll(syms)
super.applySubBlock(body)
case _ => super.applySubBlock(b)



object HandlerLowering:

private final val getLocalsNme = "getLocals"
Expand Down Expand Up @@ -996,7 +956,6 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise,

def translateTopLevel(b: Block): (Block, Map[FnOrCls, Path]) =
doUnwindMap = Map.empty
val preTransformed = new PreHandlerLowering().applyBlock(b)
val transformed = translateBlock(preTransformed, Set.empty, N, L(BlockMemberSymbol("", Nil)), topLevelCtx(s"Cont$$topLevel$$BAD", "‹top level›"))
val transformed = translateBlock(b, Set.empty, N, L(BlockMemberSymbol("", Nil)), topLevelCtx(s"Cont$$topLevel$$BAD", "‹top level›"))
(transformed, doUnwindMap)

43 changes: 32 additions & 11 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import hkmc2.codegen.llir.FreshInt

import scala.collection.mutable.LinkedHashMap
import scala.collection.mutable.Map as MutMap
import scala.collection.mutable.Set as MutSet

object Lifter:
/**
Expand Down Expand Up @@ -622,7 +623,11 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
var activeClosures: Set[Local] = Set.empty
// Map from block member symbols to initialized closures
val closureMap: MutMap[BlockMemberSymbol, Local] = MutMap.empty
val extraLocals: MutSet[Local] = MutSet.empty

def rewrite(b: Block) =
val ret = applyBlock(b)
Scoped(extraLocals, ret)

// Replaces references to BlockMemberSymbols as needed with fresh variables, and
// returns the mapping from the symbol to the required variable. When possible,
Expand All @@ -634,6 +639,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
def rewriteBms(b: Block) =
// BMS's that need to be created
val syms: LinkedHashMap[FunSyms[?], Local] = LinkedHashMap.empty
val extraLocals: MutSet[Local] = MutSet.empty

val walker = new BlockDataTransformer(SymbolSubst()):
// only scan within the block. don't traverse
Expand Down Expand Up @@ -673,6 +679,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
// $this was previously used, but it may be confused with the `this` keyword
// let's use $here instead
val newSym = TempSymbol(N, l.nme + "$here")
extraLocals.add(newSym)
syms.addOne(FunSyms(l, d) -> newSym) // add to `syms`: this closure will be initialized in `applyBlock`
closureMap.addOne(l -> newSym) // add to `closureMap`: `newSym` refers to the closure and can be used later
newSym
Expand All @@ -685,7 +692,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
value
k(Value.Ref(newSym, S(d)))
case _ => super.applyPath(p)(k)
(walker.applyBlock(b), syms.toList)
(walker.applyBlock(b), syms.toList, extraLocals)
end rewriteBms

def applySubBlockAndReset(b: Block): Block =
Expand All @@ -697,7 +704,8 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
override def applyBlock(b: Block): Block =
// extract references to BlockMemberSymbols in the block which now may
// need to be enriched with aux parameters
val (rewritten, syms) = rewriteBms(b)
val (rewritten, syms, extras) = rewriteBms(b)
extraLocals.addAll(extras)
val pre = syms.foldLeft(blockBuilder):
case (blk, (bms, local)) =>
val initial = blk.assign(local, createCall(bms, ctx))
Expand Down Expand Up @@ -777,6 +785,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
case Define(d: ClsLikeDefn, rest: Block) => ctx.modObjLocals.get(d.sym) match
case Some(sym) if !ctx.ignored(d.sym) => ctx.getBmsReqdInfo(d.sym) match
case Some(_) => // has args
extraLocals.add(sym)
blockBuilder
.assign(sym, Instantiate(mut = false, d.sym.asPath, getCallArgs(FunSyms(d.sym, d.isym), ctx)))
.rest(applyBlock(rest))
Expand Down Expand Up @@ -1030,7 +1039,9 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):

val isMutSym = VarSymbol(Tree.Ident("isMut"))

val curSyms: MutSet[Local] = MutSet.empty
var curSym = TempSymbol(None, "tmp")
curSyms.add(curSym)
def instInner(isMut: Bool) =
Instantiate(mut = isMut, Value.Ref(c.sym, S(c.isym)), paramArgs)

Expand All @@ -1046,10 +1057,11 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
for ps <- newAuxSyms do
val call = Call(curSym.asPath, ps.map(_.asPath.asArg))(true, false, false)
curSym = TempSymbol(None, "tmp")
curSyms.add(curSym)
val thisSym = curSym
acc = acc.assign(thisSym, call)
// acc = blk => acc(Assign(curSym, call, blk))
val bod = acc.ret(curSym.asPath)
val bod = Scoped(curSyms, acc.ret(curSym.asPath))

inline def toPlist(ls: List[VarSymbol]) =
PlainParamList(ls.map(s => Param(FldFlags.empty, s, N, Modulefulness.none)))
Expand Down Expand Up @@ -1096,6 +1108,10 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):

end liftOutDefnCont

def removeDefnsFromScope(b: Block, defns: List[Defn]) = b match
case Scoped(syms, body) => Scoped(syms.toSet -- defns.map(_.sym), body)
case _ => b

def liftDefnsInCls(c: ClsLikeDefn, ctx: LifterCtx): Lifted[ClsLikeDefn] =
val ctxx = if c.companion.isDefined then ctx.inModule(c) else ctx // TODO: refine handling of companions

Expand Down Expand Up @@ -1162,9 +1178,9 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):

val replacedDefnsCtx = newCtx.addreplacedDefns(ctorIgnoredRewrite)
val rewriter = BlockRewriter(newCtx.inScopeISyms, replacedDefnsCtx)
val newPreCtor = rewriter.applyBlock(preCtor)
val newCtor = rewriter.applyBlock(ctor)
val newCCtor = cCtor.map(rewriter.applyBlock(_))
val newPreCtor = removeDefnsFromScope(rewriter.rewrite(preCtor), ctorIncluded)
val newCtor = removeDefnsFromScope(rewriter.rewrite(ctor), ctorIncluded)
val newCCtor = cCtor.map(blk => removeDefnsFromScope(rewriter.rewrite(blk), ctorIncluded))

// ===========================================================
// STEP 2: rewrite non-static class methods
Expand Down Expand Up @@ -1251,10 +1267,11 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
lifted.liftedDefn.sym -> lifted.liftedDefn
.toMap

val transformed = BlockRewriter(ctx.inScopeISyms, captureCtx.addreplacedDefns(ignoredRewrite)).applyBlock(blk)
val transformed = BlockRewriter(ctx.inScopeISyms, captureCtx.addreplacedDefns(ignoredRewrite)).rewrite(blk)
val newScopedBlk = removeDefnsFromScope(transformed, included)

if thisVars.reqCapture.size == 0 then
Lifted(FunDefn(f.owner, f.sym, f.dSym, f.params, transformed)(forceTailRec = f.forceTailRec), newDefns)
Lifted(FunDefn(f.owner, f.sym, f.dSym, f.params, newScopedBlk)(forceTailRec = f.forceTailRec), newDefns)
else
// move the function's parameters to the capture
val paramsSet = f.params.flatMap(_.paramSyms)
Expand All @@ -1264,8 +1281,9 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
val bod = blockBuilder
.assign(captureSym, Instantiate(mut = true, // * Note: `mut` is needed for capture classes
captureCls.sym.asPath, paramsList))
.rest(transformed)
Lifted(FunDefn(f.owner, f.sym, f.dSym, f.params, bod)(forceTailRec = f.forceTailRec), captureCls :: newDefns)
.rest(newScopedBlk)
val withScope = Scoped(Set(captureSym), bod)
Lifted(FunDefn(f.owner, f.sym, f.dSym, f.params, withScope)(forceTailRec = f.forceTailRec), captureCls :: newDefns)

end liftDefnsInFn

Expand Down Expand Up @@ -1319,6 +1337,9 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
val ctxxx = ctxx.withDefnsCur(analyzer.nestedDeep(d.sym))
liftDefnsInCls(c, ctxxx.addBmsReqdInfo(createLiftInfoCls(c, ctxxx)))
case _ => return super.applyBlock(b)
(lifted :: extra).foldLeft(applyBlock(rest))((acc, defn) => Define(defn, acc))
val newDefns = lifted :: extra
val newBms = newDefns.map(_.sym)
val newBlk = newDefns.foldLeft(applyBlock(rest))((acc, defn) => Define(defn, acc))
Scoped(newBms.toSet, newBlk)
case _ => super.applyBlock(b)
walker1.applyBlock(blk)
Loading
Loading