From 5f62371ed1e45ac52acafa9052b46a9d9655c938 Mon Sep 17 00:00:00 2001 From: He-Pin Date: Sun, 1 Mar 2026 17:30:02 +0800 Subject: [PATCH] chore: TCO for tailstrict --- sjsonnet/src/sjsonnet/Error.scala | 18 +- sjsonnet/src/sjsonnet/Evaluator.scala | 155 +++++- sjsonnet/src/sjsonnet/Expr.scala | 45 +- sjsonnet/src/sjsonnet/Materializer.scala | 19 +- sjsonnet/src/sjsonnet/Val.scala | 111 ++++- .../src/sjsonnet/stdlib/ArrayModule.scala | 4 +- .../sjsonnet/TailCallOptimizationTests.scala | 452 ++++++++++++++++++ 7 files changed, 752 insertions(+), 52 deletions(-) create mode 100644 sjsonnet/test/src/sjsonnet/TailCallOptimizationTests.scala diff --git a/sjsonnet/src/sjsonnet/Error.scala b/sjsonnet/src/sjsonnet/Error.scala index b5efa8b1..0bb7a70b 100644 --- a/sjsonnet/src/sjsonnet/Error.scala +++ b/sjsonnet/src/sjsonnet/Error.scala @@ -17,16 +17,20 @@ class Error(msg: String, stack: List[Error.Frame] = Nil, underlying: Option[Thro def addFrame(pos: Position, expr: Expr = null)(implicit ev: EvalErrorScope): Error = { if (stack.isEmpty || alwaysAddPos(expr)) { val exprErrorString = if (expr == null) null else expr.exprErrorString - val newFrame = new Error.Frame(pos, exprErrorString) - stack match { - case s :: ss if s.pos == pos => - if (s.exprErrorString == null && exprErrorString != null) copy(stack = newFrame :: ss) - else this - case _ => copy(stack = newFrame :: stack) - } + addFrameString(pos, exprErrorString) } else this } + def addFrameString(pos: Position, exprErrorString: String)(implicit ev: EvalErrorScope): Error = { + val newFrame = new Error.Frame(pos, exprErrorString) + stack match { + case s :: ss if s.pos == pos => + if (s.exprErrorString == null && exprErrorString != null) copy(stack = newFrame :: ss) + else this + case _ => copy(stack = newFrame :: stack) + } + } + def asSeenFrom(ev: EvalErrorScope): Error = copy(stack = stack.map(_.asSeenFrom(ev))) diff --git a/sjsonnet/src/sjsonnet/Evaluator.scala b/sjsonnet/src/sjsonnet/Evaluator.scala index 
762cf1c8..0428f2fe 100644 --- a/sjsonnet/src/sjsonnet/Evaluator.scala +++ b/sjsonnet/src/sjsonnet/Evaluator.scala @@ -202,13 +202,22 @@ class Evaluator( } } + /** + * Function application entry points (visitApply/visitApply0-3 for user functions, + * visitApplyBuiltin/visitApplyBuiltin0-4 for built-in functions). + * + * When `e.tailstrict` is true, the result is wrapped in `TailCall.resolve()` which iteratively + * resolves any [[TailCall]] chain. When false, arguments are wrapped as lazy thunks to preserve + * Jsonnet's default lazy evaluation semantics, and `Val.Func.apply` resolves any TailCall + * internally via `TailCall.resolve` before returning. + */ protected def visitApply(e: Apply)(implicit scope: ValScope): Val = { val lhs = visitExpr(e.value) implicit val tailstrictMode: TailstrictMode = if (e.tailstrict) TailstrictModeEnabled else TailstrictModeDisabled if (e.tailstrict) { - lhs.cast[Val.Func].apply(e.args.map(visitExpr(_)), e.namedNames, e.pos) + TailCall.resolve(lhs.cast[Val.Func].apply(e.args.map(visitExpr(_)), e.namedNames, e.pos)) } else { lhs.cast[Val.Func].apply(e.args.map(visitAsLazy(_)), e.namedNames, e.pos) } @@ -218,7 +227,11 @@ class Evaluator( val lhs = visitExpr(e.value) implicit val tailstrictMode: TailstrictMode = if (e.tailstrict) TailstrictModeEnabled else TailstrictModeDisabled - lhs.cast[Val.Func].apply0(e.pos) + if (e.tailstrict) { + TailCall.resolve(lhs.cast[Val.Func].apply0(e.pos)) + } else { + lhs.cast[Val.Func].apply0(e.pos) + } } protected def visitApply1(e: Apply1)(implicit scope: ValScope): Val = { @@ -226,7 +239,7 @@ class Evaluator( implicit val tailstrictMode: TailstrictMode = if (e.tailstrict) TailstrictModeEnabled else TailstrictModeDisabled if (e.tailstrict) { - lhs.cast[Val.Func].apply1(visitExpr(e.a1), e.pos) + TailCall.resolve(lhs.cast[Val.Func].apply1(visitExpr(e.a1), e.pos)) } else { val l1 = visitAsLazy(e.a1) lhs.cast[Val.Func].apply1(l1, e.pos) @@ -239,7 +252,7 @@ class Evaluator( if (e.tailstrict) 
TailstrictModeEnabled else TailstrictModeDisabled if (e.tailstrict) { - lhs.cast[Val.Func].apply2(visitExpr(e.a1), visitExpr(e.a2), e.pos) + TailCall.resolve(lhs.cast[Val.Func].apply2(visitExpr(e.a1), visitExpr(e.a2), e.pos)) } else { val l1 = visitAsLazy(e.a1) val l2 = visitAsLazy(e.a2) @@ -253,7 +266,9 @@ class Evaluator( if (e.tailstrict) TailstrictModeEnabled else TailstrictModeDisabled if (e.tailstrict) { - lhs.cast[Val.Func].apply3(visitExpr(e.a1), visitExpr(e.a2), visitExpr(e.a3), e.pos) + TailCall.resolve( + lhs.cast[Val.Func].apply3(visitExpr(e.a1), visitExpr(e.a2), visitExpr(e.a3), e.pos) + ) } else { val l1 = visitAsLazy(e.a1) val l2 = visitAsLazy(e.a2) @@ -262,11 +277,14 @@ class Evaluator( } } - protected def visitApplyBuiltin0(e: ApplyBuiltin0): Val = e.func.evalRhs(this, e.pos) + protected def visitApplyBuiltin0(e: ApplyBuiltin0): Val = { + val result = e.func.evalRhs(this, e.pos) + if (e.tailstrict) TailCall.resolve(result) else result + } protected def visitApplyBuiltin1(e: ApplyBuiltin1)(implicit scope: ValScope): Val = { if (e.tailstrict) { - e.func.evalRhs(visitExpr(e.a1), this, e.pos) + TailCall.resolve(e.func.evalRhs(visitExpr(e.a1), this, e.pos)) } else { e.func.evalRhs(visitAsLazy(e.a1), this, e.pos) } @@ -274,7 +292,7 @@ class Evaluator( protected def visitApplyBuiltin2(e: ApplyBuiltin2)(implicit scope: ValScope): Val = { if (e.tailstrict) { - e.func.evalRhs(visitExpr(e.a1), visitExpr(e.a2), this, e.pos) + TailCall.resolve(e.func.evalRhs(visitExpr(e.a1), visitExpr(e.a2), this, e.pos)) } else { e.func.evalRhs(visitAsLazy(e.a1), visitAsLazy(e.a2), this, e.pos) } @@ -282,7 +300,9 @@ class Evaluator( protected def visitApplyBuiltin3(e: ApplyBuiltin3)(implicit scope: ValScope): Val = { if (e.tailstrict) { - e.func.evalRhs(visitExpr(e.a1), visitExpr(e.a2), visitExpr(e.a3), this, e.pos) + TailCall.resolve( + e.func.evalRhs(visitExpr(e.a1), visitExpr(e.a2), visitExpr(e.a3), this, e.pos) + ) } else { e.func.evalRhs(visitAsLazy(e.a1), 
visitAsLazy(e.a2), visitAsLazy(e.a3), this, e.pos) } @@ -290,13 +310,15 @@ class Evaluator( protected def visitApplyBuiltin4(e: ApplyBuiltin4)(implicit scope: ValScope): Val = { if (e.tailstrict) { - e.func.evalRhs( - visitExpr(e.a1), - visitExpr(e.a2), - visitExpr(e.a3), - visitExpr(e.a4), - this, - e.pos + TailCall.resolve( + e.func.evalRhs( + visitExpr(e.a1), + visitExpr(e.a2), + visitExpr(e.a3), + visitExpr(e.a4), + this, + e.pos + ) ) } else { e.func.evalRhs( @@ -319,7 +341,7 @@ class Evaluator( arr(idx) = visitExpr(e.argExprs(idx)) idx += 1 } - e.func.evalRhs(arr, this, e.pos) + TailCall.resolve(e.func.evalRhs(arr, this, e.pos)) } else { while (idx < e.argExprs.length) { val boundIdx = idx @@ -638,10 +660,107 @@ class Evaluator( scope: ValScope): Val.Func = new Val.Func(outerPos, scope, params) { def evalRhs(vs: ValScope, es: EvalScope, fs: FileScope, pos: Position): Val = - visitExpr(rhs)(vs) + visitExprWithTailCallSupport(rhs)(vs) override def evalDefault(expr: Expr, vs: ValScope, es: EvalScope): Val = visitExpr(expr)(vs) } + /** + * Evaluate an expression with tail-call support. When a `tailstrict` call is encountered at a + * potential tail position, returns a [[TailCall]] sentinel instead of recursing, enabling + * `TailCall.resolve` in `visitApply*` to iterate rather than grow the JVM stack. + * + * Potential tail positions are propagated through: IfElse (both branches), LocalExpr (returned), + * and AssertExpr (returned). All other expression types delegate to normal `visitExpr`. 
+ */ + @tailrec + private def visitExprWithTailCallSupport(e: Expr)(implicit scope: ValScope): Val = e match { + case e: IfElse => + visitExpr(e.cond) match { + case Val.True(_) => visitExprWithTailCallSupport(e.`then`) + case Val.False(_) => + e.`else` match { + case null => Val.Null(e.pos) + case v => visitExprWithTailCallSupport(v) + } + case v => Error.fail("Need boolean, found " + v.prettyName, e.pos) + } + case e: LocalExpr => + val bindings = e.bindings + val s = + if (bindings == null) scope + else { + val base = scope.length + val newScope = scope.extendBy(bindings.length) + var i = 0 + while (i < bindings.length) { + val b = bindings(i) + newScope.bindings(base + i) = b.args match { + case null => visitAsLazy(b.rhs)(newScope) + case argSpec => + new Lazy(() => visitMethod(b.rhs, argSpec, b.pos)(newScope)) + } + i += 1 + } + newScope + } + visitExprWithTailCallSupport(e.returned)(s) + case e: AssertExpr => + if (!visitExpr(e.asserted.value).isInstanceOf[Val.True]) { + e.asserted.msg match { + case null => Error.fail("Assertion failed", e) + case msg => + Error.fail("Assertion failed: " + materializeError(visitExpr(msg)), e) + } + } + visitExprWithTailCallSupport(e.returned) + // Tail-position tailstrict calls: match TailstrictableExpr to unify the tailstrict guard, + // then dispatch by concrete type. + // + // - Apply* (user function calls): construct a TailCall sentinel that the caller's + // TailCall.resolve loop will resolve iteratively, avoiding JVM stack growth for + // tail-recursive calls. + // - ApplyBuiltin* (built-in function calls): fall through to visitExpr, which dispatches to + // visitApplyBuiltin*. Those methods already wrap their result in TailCall.resolve() when + // tailstrict=true, resolving any TailCall that a user-defined callback (e.g. the function + // argument to std.makeArray or std.sort) may have returned. 
+ case e: TailstrictableExpr if e.tailstrict => + e match { + case e: Apply => + try { + val func = visitExpr(e.value).cast[Val.Func] + new TailCall(func, e.args.map(visitExpr(_)).asInstanceOf[Array[Eval]], e.namedNames, e) + } catch Error.withStackFrame(e) + case e: Apply0 => + try { + val func = visitExpr(e.value).cast[Val.Func] + new TailCall(func, Evaluator.emptyLazyArray, null, e) + } catch Error.withStackFrame(e) + case e: Apply1 => + try { + val func = visitExpr(e.value).cast[Val.Func] + new TailCall(func, Array[Eval](visitExpr(e.a1)), null, e) + } catch Error.withStackFrame(e) + case e: Apply2 => + try { + val func = visitExpr(e.value).cast[Val.Func] + new TailCall(func, Array[Eval](visitExpr(e.a1), visitExpr(e.a2)), null, e) + } catch Error.withStackFrame(e) + case e: Apply3 => + try { + val func = visitExpr(e.value).cast[Val.Func] + new TailCall( + func, + Array[Eval](visitExpr(e.a1), visitExpr(e.a2), visitExpr(e.a3)), + null, + e + ) + } catch Error.withStackFrame(e) + case _ => visitExpr(e) + } + case _ => + visitExpr(e) + } + def visitBindings(bindings: Array[Bind], scope: => ValScope): Array[Eval] = { val arrF = new Array[Eval](bindings.length) var i = 0 diff --git a/sjsonnet/src/sjsonnet/Expr.scala b/sjsonnet/src/sjsonnet/Expr.scala index 211958a1..943ec161 100644 --- a/sjsonnet/src/sjsonnet/Expr.scala +++ b/sjsonnet/src/sjsonnet/Expr.scala @@ -23,6 +23,27 @@ trait Expr { override def toString: String = s"$exprErrorString@$pos" } + +/** + * Marker trait for [[Expr]] nodes that represent function calls eligible for tail-call + * optimization. All Apply* (user function calls) and ApplyBuiltin* (built-in function calls) mix in + * this trait, providing a uniform `tailstrict` flag. 
The evaluator handles the two families + * differently when `tailstrict` is true: + * + * - '''User function calls''' (Apply*) in tail position: the evaluator constructs a [[TailCall]] + * sentinel and returns it to the caller's [[TailCall.resolve]] trampoline loop, avoiding JVM + * stack growth for tail-recursive calls. + * - '''Built-in function calls''' (ApplyBuiltin*): the evaluator wraps the result in + * [[TailCall.resolve]] at the call site, resolving any [[TailCall]] that a user-defined + * callback (e.g. the function argument to `std.makeArray` or `std.sort`) may have returned. + * + * @see + * [[TailCall]] for the sentinel value used in the TCO protocol + */ +trait TailstrictableExpr extends Expr { + def tailstrict: Boolean +} + object Expr { private final def arrStr(a: Array[?]): String = { if (a == null) "null" else a.mkString("[", ", ", "]") @@ -189,17 +210,19 @@ object Expr { args: Array[Expr], namedNames: Array[String], tailstrict: Boolean) - extends Expr { + extends TailstrictableExpr { final override private[sjsonnet] def tag = ExprTags.Apply } - final case class Apply0(pos: Position, value: Expr, tailstrict: Boolean) extends Expr { + final case class Apply0(pos: Position, value: Expr, tailstrict: Boolean) + extends TailstrictableExpr { final override private[sjsonnet] def tag = ExprTags.Apply0 } - final case class Apply1(pos: Position, value: Expr, a1: Expr, tailstrict: Boolean) extends Expr { + final case class Apply1(pos: Position, value: Expr, a1: Expr, tailstrict: Boolean) + extends TailstrictableExpr { final override private[sjsonnet] def tag = ExprTags.Apply1 } final case class Apply2(pos: Position, value: Expr, a1: Expr, a2: Expr, tailstrict: Boolean) - extends Expr { + extends TailstrictableExpr { final override private[sjsonnet] def tag = ExprTags.Apply2 } final case class Apply3( @@ -209,7 +232,7 @@ object Expr { a2: Expr, a3: Expr, tailstrict: Boolean) - extends Expr { + extends TailstrictableExpr { final override private[sjsonnet] def 
tag = ExprTags.Apply3 } final case class ApplyBuiltin( @@ -217,17 +240,17 @@ object Expr { func: Val.Builtin, argExprs: Array[Expr], tailstrict: Boolean) - extends Expr { + extends TailstrictableExpr { final override private[sjsonnet] def tag = ExprTags.ApplyBuiltin override def exprErrorString: String = s"std.${func.functionName}" } final case class ApplyBuiltin0(pos: Position, func: Val.Builtin0, tailstrict: Boolean) - extends Expr { + extends TailstrictableExpr { final override private[sjsonnet] def tag = ExprTags.ApplyBuiltin0 override def exprErrorString: String = s"std.${func.functionName}" } final case class ApplyBuiltin1(pos: Position, func: Val.Builtin1, a1: Expr, tailstrict: Boolean) - extends Expr { + extends TailstrictableExpr { final override private[sjsonnet] def tag = ExprTags.ApplyBuiltin1 override def exprErrorString: String = s"std.${func.functionName}" } @@ -237,7 +260,7 @@ object Expr { a1: Expr, a2: Expr, tailstrict: Boolean) - extends Expr { + extends TailstrictableExpr { final override private[sjsonnet] def tag = ExprTags.ApplyBuiltin2 override def exprErrorString: String = s"std.${func.functionName}" } @@ -248,7 +271,7 @@ object Expr { a2: Expr, a3: Expr, tailstrict: Boolean) - extends Expr { + extends TailstrictableExpr { final override private[sjsonnet] def tag = ExprTags.ApplyBuiltin3 override def exprErrorString: String = s"std.${func.functionName}" } @@ -260,7 +283,7 @@ object Expr { a3: Expr, a4: Expr, tailstrict: Boolean) - extends Expr { + extends TailstrictableExpr { override private[sjsonnet] def tag = ExprTags.ApplyBuiltin4 override def exprErrorString: String = s"std.${func.functionName}" } diff --git a/sjsonnet/src/sjsonnet/Materializer.scala b/sjsonnet/src/sjsonnet/Materializer.scala index 0661d997..46694eaf 100644 --- a/sjsonnet/src/sjsonnet/Materializer.scala +++ b/sjsonnet/src/sjsonnet/Materializer.scala @@ -54,16 +54,23 @@ abstract class Materializer { i += 1 } arrVisitor.visitEnd(-1) - case Val.True(pos) => storePos(pos); 
visitor.visitTrue(-1) - case Val.False(pos) => storePos(pos); visitor.visitFalse(-1) - case Val.Null(pos) => storePos(pos); visitor.visitNull(-1) - case s: Val.Func => + case Val.True(pos) => storePos(pos); visitor.visitTrue(-1) + case Val.False(pos) => storePos(pos); visitor.visitFalse(-1) + case Val.Null(pos) => storePos(pos); visitor.visitNull(-1) + case mat: Materializer.Materializable => storePos(v.pos); mat.materialize(visitor) + case s: Val.Func => Error.fail( "Couldn't manifest function with params [" + s.params.names.mkString(",") + "]", v.pos ) - case mat: Materializer.Materializable => storePos(v.pos); mat.materialize(visitor) - case vv: Val => + case tc: TailCall => + Error.fail( + "Internal error: TailCall sentinel leaked into materialization. " + + "This indicates a bug in the TCO protocol — a TailCall was not resolved before " + + "reaching the Materializer.", + tc.pos + ) + case vv: Val => Error.fail("Unknown value type " + vv.prettyName, vv.pos) case null => Error.fail("Unknown value type " + v) diff --git a/sjsonnet/src/sjsonnet/Val.scala b/sjsonnet/src/sjsonnet/Val.scala index 7ea602d2..17be8e18 100644 --- a/sjsonnet/src/sjsonnet/Val.scala +++ b/sjsonnet/src/sjsonnet/Val.scala @@ -664,6 +664,12 @@ object Val { def evalRhs(scope: ValScope, ev: EvalScope, fs: FileScope, pos: Position): Val + // Convenience wrapper: evaluates the function body and resolves any TailCall sentinel. + // Use this instead of raw `evalRhs` at call sites that bypass `apply*` and consume + // the result directly (e.g. stdlib scope-reuse fast paths). + final def evalRhsResolved(scope: ValScope, ev: EvalScope, fs: FileScope, pos: Position): Val = + TailCall.resolve(evalRhs(scope, ev, fs, pos))(ev) + def evalDefault(expr: Expr, vs: ValScope, es: EvalScope): Val = null def prettyName = "function" @@ -672,6 +678,15 @@ object Val { override def asFunc: Func = this + /** + * Core function application with tail call optimization (TCO) support. 
+ * + * TCO protocol: when `tailstrictMode == TailstrictModeEnabled`, `evalRhs` may return a + * [[TailCall]] sentinel which is propagated back to the caller's [[TailCall.resolve]] loop + * without resolution. When `tailstrictMode == TailstrictModeDisabled` (the common case — called + * from std library, object fields, etc.), any TailCall is resolved here via `TailCall.resolve` + * to prevent sentinel leakage. + */ def apply(argsL: Array[? <: Eval], namedNames: Array[String], outerPos: Position)(implicit ev: EvalScope, tailstrictMode: TailstrictMode): Val = { @@ -680,13 +695,13 @@ object Val { case null => outerPos.fileScope case p => p.fileScope } - // println(s"apply: argsL: ${argsL.length}, namedNames: $namedNames, paramNames: ${params.names.mkString(",")}") if (simple) { if (tailstrictMode == TailstrictModeEnabled) { argsL.foreach(_.value) } val newScope = defSiteValScope.extendSimple(argsL) - evalRhs(newScope, ev, funDefFileScope, outerPos) + val result = evalRhs(newScope, ev, funDefFileScope, outerPos) + if (tailstrictMode == TailstrictModeDisabled) TailCall.resolve(result) else result } else { val newScopeLen = math.max(params.names.length, argsL.length) // Initialize positional args @@ -743,10 +758,16 @@ object Val { if (tailstrictMode == TailstrictModeEnabled) { argVals.foreach(_.value) } - evalRhs(newScope, ev, funDefFileScope, outerPos) + val result = evalRhs(newScope, ev, funDefFileScope, outerPos) + if (tailstrictMode == TailstrictModeDisabled) TailCall.resolve(result) else result } } + // apply0–apply3: fast paths for the most common call arities, called from + // Evaluator.visitApply0–visitApply3. When the arity matches exactly and there are + // no named/default arguments, these skip the general-purpose scope-extension logic + // in `apply` (named-arg mapping, defaults filling, arraycopy) and use the cheaper + // `ValScope.extendSimple` instead. 
def apply0(outerPos: Position)(implicit ev: EvalScope, tailstrictMode: TailstrictMode): Val = { if (params.names.length != 0) apply(Evaluator.emptyLazyArray, null, outerPos) else { @@ -754,7 +775,8 @@ object Val { case null => outerPos.fileScope case p => p.fileScope } - evalRhs(defSiteValScope, ev, funDefFileScope, outerPos) + val result = evalRhs(defSiteValScope, ev, funDefFileScope, outerPos) + if (tailstrictMode == TailstrictModeDisabled) TailCall.resolve(result) else result } } @@ -771,7 +793,8 @@ object Val { argVal.value } val newScope: ValScope = defSiteValScope.extendSimple(argVal) - evalRhs(newScope, ev, funDefFileScope, outerPos) + val result = evalRhs(newScope, ev, funDefFileScope, outerPos) + if (tailstrictMode == TailstrictModeDisabled) TailCall.resolve(result) else result } } @@ -789,7 +812,8 @@ object Val { argVal2.value } val newScope: ValScope = defSiteValScope.extendSimple(argVal1, argVal2) - evalRhs(newScope, ev, funDefFileScope, outerPos) + val result = evalRhs(newScope, ev, funDefFileScope, outerPos) + if (tailstrictMode == TailstrictModeDisabled) TailCall.resolve(result) else result } } @@ -808,12 +832,22 @@ object Val { argVal3.value } val newScope: ValScope = defSiteValScope.extendSimple(argVal1, argVal2, argVal3) - evalRhs(newScope, ev, funDefFileScope, outerPos) + val result = evalRhs(newScope, ev, funDefFileScope, outerPos) + if (tailstrictMode == TailstrictModeDisabled) TailCall.resolve(result) else result } } } - /** Superclass for standard library functions */ + /** + * Superclass for standard library functions. + * + * TCO note: the arity-specialized overrides (`apply1`–`apply3`) intentionally omit the + * `TailCall.resolve` guard present in [[Func.apply1]]–[[Func.apply3]]. This is safe because + * built-in `evalRhs` implementations are concrete Scala code that never produce [[TailCall]] + * sentinels directly. When a built-in internally invokes a user-defined callback (e.g. 
+ * `std.makeArray`, `std.sort`), it passes `TailstrictModeDisabled` explicitly, so the callback's + * own `Val.Func.apply*` resolves any TailCall before returning. + */ abstract class Builtin( val functionName: String, paramNames: Array[String], @@ -837,6 +871,9 @@ object Val { def evalRhs(args: Array[? <: Eval], ev: EvalScope, pos: Position): Val + // No TailCall.resolve needed: Builtin evalRhs is pure Scala and never produces TailCall. + // When builtins invoke user callbacks internally, they pass TailstrictModeDisabled, + // so the callback's own Func.apply* resolves any TailCall before returning. override def apply1(argVal: Eval, outerPos: Position)(implicit ev: EvalScope, tailstrictMode: TailstrictMode): Val = @@ -991,10 +1028,68 @@ object Val { } } +/** + * Discriminator for the TCO protocol, passed as an implicit through the call chain. + * + * Using a sealed trait (rather than a plain Boolean) gives the JVM JIT better type-profile + * information at `if` guards, and makes the two modes self-documenting at call sites. + * + * - [[TailstrictModeEnabled]]: caller will handle TailCall via [[TailCall.resolve]]; sentinels + * may be returned without resolution. + * - [[TailstrictModeDisabled]]: normal call; any TailCall must be resolved before returning. + */ sealed trait TailstrictMode case object TailstrictModeEnabled extends TailstrictMode case object TailstrictModeDisabled extends TailstrictMode +/** + * Sentinel value for tail call optimization of `tailstrict` calls. When a function body's tail + * position is a `tailstrict` call, the evaluator returns a [[TailCall]] instead of recursing into + * the callee. [[TailCall.resolve]] then re-invokes the target function iteratively, eliminating + * native stack growth. + * + * This is an internal protocol value and must never escape to user-visible code paths (e.g. + * materialization, object field access). 
Every call site that may produce a TailCall must either + * pass `TailstrictModeEnabled` (so the caller resolves it) or guard the result with + * [[TailCall.resolve]]. + */ +final class TailCall( + val func: Val.Func, + val args: Array[Eval], + val namedNames: Array[String], + val callSiteExpr: Expr) + extends Val { + def pos: Position = callSiteExpr.pos + def prettyName = "tailcall" + def exprErrorString: String = callSiteExpr.exprErrorString +} + +object TailCall { + + /** + * Iteratively resolve a [[TailCall]] chain (trampoline loop). If `current` is not a TailCall, it + * is returned immediately. Otherwise, each TailCall's target function is re-invoked with + * `TailstrictModeEnabled` until a non-TailCall result is produced. + * + * Error frames preserve the original call-site expression name (e.g. "Apply2") so that TCO does + * not alter user-visible stack traces. + */ + @tailrec + def resolve(current: Val)(implicit ev: EvalScope): Val = current match { + case tc: TailCall => + implicit val tailstrictMode: TailstrictMode = TailstrictModeEnabled + val next = + try { + tc.func.apply(tc.args, tc.namedNames, tc.callSiteExpr.pos) + } catch { + case e: Error => + throw e.addFrame(tc.callSiteExpr.pos, tc.callSiteExpr) + } + resolve(next) + case result => result + } +} + /** * [[EvalScope]] models the per-evaluator context that is propagated throughout the Jsonnet * evaluation. 
diff --git a/sjsonnet/src/sjsonnet/stdlib/ArrayModule.scala b/sjsonnet/src/sjsonnet/stdlib/ArrayModule.scala index 8254c92c..27236022 100644 --- a/sjsonnet/src/sjsonnet/stdlib/ArrayModule.scala +++ b/sjsonnet/src/sjsonnet/stdlib/ArrayModule.scala @@ -127,13 +127,13 @@ object ArrayModule extends AbstractFunctionModule { val scopeIdx = newScope.length - 1 while (i < a.length) { newScope.bindings(scopeIdx) = a(i) - if (!func.evalRhs(newScope, ev, funDefFileScope, p).asBoolean) { + if (!func.evalRhsResolved(newScope, ev, funDefFileScope, p).asBoolean) { var b = new Array[Eval](a.length - 1) System.arraycopy(a, 0, b, 0, i) var j = i + 1 while (j < a.length) { newScope.bindings(scopeIdx) = a(j) - if (func.evalRhs(newScope, ev, funDefFileScope, p).asBoolean) { + if (func.evalRhsResolved(newScope, ev, funDefFileScope, p).asBoolean) { b(i) = a(j) i += 1 } diff --git a/sjsonnet/test/src/sjsonnet/TailCallOptimizationTests.scala b/sjsonnet/test/src/sjsonnet/TailCallOptimizationTests.scala new file mode 100644 index 00000000..addb9497 --- /dev/null +++ b/sjsonnet/test/src/sjsonnet/TailCallOptimizationTests.scala @@ -0,0 +1,452 @@ +package sjsonnet + +import utest._ +import TestUtils.{eval, evalErr} + +object TailCallOptimizationTests extends TestSuite { + val tests: Tests = Tests { + test("tailstrictFactorialSmall") { + eval( + """ + |local factorial(n, accum=1) = + | if n <= 1 then accum + | else factorial(n - 1, n * accum) tailstrict; + | + |factorial(10) + |""".stripMargin + ) ==> ujson.Num(3628800) + } + + test("tailstrictFactorialOverflow") { + // factorial(1000) overflows IEEE 754 double, sjsonnet should report overflow + val err = evalErr( + """ + |local factorial(n, accum=1) = + | if n <= 1 then accum + | else factorial(n - 1, n * accum) tailstrict; + | + |factorial(1000) + |""".stripMargin + ) + assert(err.contains("overflow")) + } + + test("tailstrictDeepRecursionSum") { + // Sum 1..10000 via tail-recursive accumulator — verifies TCO prevents stack overflow + eval( + 
""" + |local sum(n, accum=0) = + | if n <= 0 then accum + | else sum(n - 1, accum + n) tailstrict; + | + |local sz = 10000; + |std.assertEqual(sum(sz), sz * (sz + 1) / 2) + |""".stripMargin + ) ==> ujson.True + } + + test("tailstrictDeepRecursionCountdown") { + // 100000 recursive calls — would blow the JVM stack without TCO + eval( + """ + |local countdown(n) = + | if n <= 0 then 0 + | else countdown(n - 1) tailstrict; + | + |countdown(100000) + |""".stripMargin + ) ==> ujson.Num(0) + } + + test("tailstrictWithDefaultParams") { + // Verify tailstrict works correctly with default parameter values + eval( + """ + |local f(n, step=1, accum=0) = + | if n <= 0 then accum + | else f(n - step, accum=accum + n) tailstrict; + | + |f(100) + |""".stripMargin + ) ==> ujson.Num(5050) + } + + test("tailstrictMutuallyIndirect") { + // Tailstrict through if-else tail position propagation + eval( + """ + |local f(n, accum=0) = + | if n <= 0 then accum + | else if n % 2 == 0 then f(n - 1, accum + n) tailstrict + | else f(n - 1, accum + n) tailstrict; + | + |f(1000) + |""".stripMargin + ) ==> ujson.Num(500500) + } + + test("tailstrictThroughLocal") { + // Tailstrict call in tail position after local binding + eval( + """ + |local f(n, accum=0) = + | if n <= 0 then accum + | else + | local next = n - 1; + | local added = accum + n; + | f(next, added) tailstrict; + | + |f(10000) + |""".stripMargin + ) ==> ujson.Num(50005000) + } + + test("tailstrictThroughAssert") { + // Tailstrict call in tail position after assert + eval( + """ + |local f(n, accum=0) = + | assert n >= 0 : "n must be non-negative"; + | if n == 0 then accum + | else f(n - 1, accum + n) tailstrict; + | + |f(1000) + |""".stripMargin + ) ==> ujson.Num(500500) + } + + test("tailstrictBuiltinHigherOrder") { + // Verify that a builtin higher-order function (std.makeArray) correctly resolves + // TailCall produced by a user callback that uses tailstrict internally. 
+ // std.makeArray calls the callback with TailstrictModeDisabled, so the callback's + // own Val.Func.apply* must resolve any TailCall before returning to the builtin. + eval( + """ + |local double(n, accum=0) = + | if n <= 0 then accum + | else double(n - 1, accum + 2) tailstrict; + | + |std.makeArray(5, function(i) double(i)) + |""".stripMargin + ) ==> ujson.Arr(ujson.Num(0), ujson.Num(2), ujson.Num(4), ujson.Num(6), ujson.Num(8)) + } + + test("tailstrictBuiltinFilterDirectTailstrict") { + // Regression: the predicate's function body is *directly* a tailstrict call + // (not wrapped in a non-tailstrict intermediate call). When std.filter's + // scope-reuse fast path calls evalRhs, visitExprWithTailCallSupport returns + // a TailCall sentinel because the outermost expression is `tailstrict`. + // Without evalRhsResolved, .asBoolean would fail on the TailCall sentinel. + eval( + """ + |local identity(x) = x; + |local pred(x) = identity(x > 0) tailstrict; + |std.filter(pred, [1, -1, 2, -3, 4]) + |""".stripMargin + ) ==> ujson.Arr(ujson.Num(1), ujson.Num(2), ujson.Num(4)) + } + + test("tailstrictBuiltinFilterDirectTailstrictAllPass") { + // All elements pass — exercises the first evalRhs call site (line 129) + // where the predicate body is directly a tailstrict call. + eval( + """ + |local identity(x) = x; + |local pred(x) = identity(x > 0) tailstrict; + |std.filter(pred, [1, 2, 3]) + |""".stripMargin + ) ==> ujson.Arr(ujson.Num(1), ujson.Num(2), ujson.Num(3)) + } + + test("tailstrictBuiltinFilterDirectTailstrictAllReject") { + // All elements rejected — exercises both call sites with a predicate + // whose body is directly a tailstrict call returning false. 
+      eval(
+        """
+          |local identity(x) = x;
+          |local pred(x) = identity(x < 0) tailstrict;
+          |std.filter(pred, [1, 2, 3])
+          |""".stripMargin
+      ) ==> ujson.Arr()
+    }
+
+    test("tailstrictZeroArgs") {
+      // Apply0: zero-argument tailstrict call
+      eval(
+        """
+          |local x() = 42;
+          |x() tailstrict
+          |""".stripMargin
+      ) ==> ujson.Num(42)
+    }
+
+    test("tailstrictThreeArgs") {
+      // Apply3: three-argument tailstrict call with deep recursion
+      eval(
+        """
+          |local f(n, a, b) =
+          |  if n <= 0 then a + b
+          |  else f(n - 1, a + 1, b + 1) tailstrict;
+          |
+          |f(10000, 0, 0)
+          |""".stripMargin
+      ) ==> ujson.Num(20000)
+    }
+
+    test("tailstrictNamedArgs") {
+      // Apply with named arguments (passed out of declaration order) in a tailstrict call
+      eval(
+        """
+          |local f(n, accum=0) =
+          |  if n <= 0 then accum
+          |  else f(accum=accum + n, n=n - 1) tailstrict;
+          |
+          |f(100)
+          |""".stripMargin
+      ) ==> ujson.Num(5050)
+    }
+
+    test("tailstrictEagerParamEvaluation") {
+      // tailstrict forces eager evaluation of arguments — error in unused param should trigger
+      val err = evalErr(
+        """
+          |local f(x, y) = x;
+          |f(42, error "kaboom") tailstrict
+          |""".stripMargin
+      )
+      assert(err.contains("kaboom"))
+    }
+
+    test("nonTailstrictLazyParams") {
+      // Without tailstrict, unused error param should NOT trigger (lazy evaluation)
+      eval(
+        """
+          |local f(x, y) = x;
+          |f(42, error "kaboom")
+          |""".stripMargin
+      ) ==> ujson.Num(42)
+    }
+
+    test("tailstrictErrorStackFrame") {
+      // An error raised at the bottom of a tailstrict recursion must still surface to the caller.
+      // NOTE(review): only the error message is asserted here; the contents of the reported
+      // stack frames are not checked by this test.
+      val err = evalErr(
+        """
+          |local f(n) =
+          |  if n <= 0 then error "reached zero"
+          |  else f(n - 1) tailstrict;
+          |
+          |f(3)
+          |""".stripMargin
+      )
+      assert(err.contains("reached zero"))
+    }
+
+    test("tailstrictChainedCalls") {
+      // Mutual recursion via object methods. Binds introduced by separate `local` statements
+      // cannot see later ones, so even/odd are defined as hidden methods of a single object
+      // and reach each other through `fns`.
+      eval(
+        """
+          |local fns = {
+          |  even(n)::
+          |    if n == 0 then true
+          |    else fns.odd(n - 1) tailstrict,
+          |  odd(n)::
+          |    if n == 0 then false
+          |    else fns.even(n - 1) tailstrict,
+          |};
+          |
+          |fns.even(1000)
+          |""".stripMargin
+      ) ==> ujson.True
+    }
+
+    // ---- Materializer integration tests ----
+    // These verify that TailCall sentinels never leak into the Materializer.
+    // If a TailCall escapes, the Materializer would hit "Unknown value type tailcall"
+    // instead of producing valid JSON.
+
+    test("materializeObjectFieldFromTailstrict") {
+      // Object field value computed via tailstrict recursion — Materializer must see
+      // the resolved Val, not a TailCall sentinel.
+      eval(
+        """
+          |local sum(n, accum=0) =
+          |  if n <= 0 then accum
+          |  else sum(n - 1, accum + n) tailstrict;
+          |
+          |{ result: sum(100) }
+          |""".stripMargin
+      ) ==> ujson.Obj("result" -> ujson.Num(5050))
+    }
+
+    test("materializeArrayElementFromTailstrict") {
+      // Array element computed via tailstrict recursion — each element must be
+      // fully resolved before the Materializer iterates over the array.
+      eval(
+        """
+          |local fib(n, a=0, b=1) =
+          |  if n <= 0 then a
+          |  else fib(n - 1, b, a + b) tailstrict;
+          |
+          |[fib(0), fib(1), fib(5), fib(10)]
+          |""".stripMargin
+      ) ==> ujson.Arr(ujson.Num(0), ujson.Num(1), ujson.Num(5), ujson.Num(55))
+    }
+
+    test("materializeNestedObjectFromTailstrict") {
+      // Deeply nested object where multiple fields are computed via tailstrict.
+      // Tests that the iterative Materializer stack correctly handles resolved values
+      // at every nesting level.
+      eval(
+        """
+          |local countdown(n) =
+          |  if n <= 0 then 0
+          |  else countdown(n - 1) tailstrict;
+          |
+          |{
+          |  outer: {
+          |    inner: {
+          |      value: countdown(1000),
+          |    },
+          |    sibling: countdown(500),
+          |  },
+          |}
+          |""".stripMargin
+      ) ==> ujson.Obj(
+        "outer" -> ujson.Obj(
+          "inner" -> ujson.Obj("value" -> ujson.Num(0)),
+          "sibling" -> ujson.Num(0)
+        )
+      )
+    }
+
+    test("materializeMixedContainerFromTailstrict") {
+      // Mixed array-of-objects where both container types contain tailstrict-computed values.
+      // Exercises the Materializer's MaterializeObjFrame/MaterializeArrFrame stack interleaving.
+      eval(
+        """
+          |local double(n, accum=0) =
+          |  if n <= 0 then accum
+          |  else double(n - 1, accum + 2) tailstrict;
+          |
+          |[
+          |  { x: double(3) },
+          |  { x: double(5) },
+          |]
+          |""".stripMargin
+      ) ==> ujson.Arr(
+        ujson.Obj("x" -> ujson.Num(6)),
+        ujson.Obj("x" -> ujson.Num(10))
+      )
+    }
+
+    test("materializeLazyFieldFromTailstrict") {
+      // Object field that is lazily evaluated — the tailstrict call happens inside
+      // a Lazy thunk that is only forced when the Materializer accesses the field.
+      eval(
+        """
+          |local sum(n, accum=0) =
+          |  if n <= 0 then accum
+          |  else sum(n - 1, accum + n) tailstrict;
+          |
+          |local obj = { a: sum(50), b: sum(100) };
+          |[obj.a, obj.b]
+          |""".stripMargin
+      ) ==> ujson.Arr(ujson.Num(1275), ujson.Num(5050))
+    }
+
+    test("materializeStringifyFromTailstrict") {
+      // std.toString forces materialization to string — verifies TailCall is resolved
+      // before the Renderer visitor processes the value.
+      eval(
+        """
+          |local repeat(n, s="", accum="") =
+          |  if n <= 0 then accum
+          |  else repeat(n - 1, s, accum + s) tailstrict;
+          |
+          |std.toString({ msg: repeat(3, "ab") })
+          |""".stripMargin
+      ) ==> ujson.Str("""{"msg": "ababab"}""")
+    }
+
+    test("tailstrictTwoArgs") {
+      // Apply2: two-argument tailstrict call — intended to exercise the
+      // visitApply2 / Val.Func.apply2 code path specifically. Euclid's gcd on these
+      // inputs recurses only a few levels, so this checks the two-argument shape
+      // rather than recursion depth.
+      eval(
+        """
+          |local gcd(a, b) =
+          |  if b == 0 then a
+          |  else gcd(b, a % b) tailstrict;
+          |
+          |[gcd(48, 18), gcd(100, 75), gcd(17, 13)]
+          |""".stripMargin
+      ) ==> ujson.Arr(ujson.Num(6), ujson.Num(25), ujson.Num(1))
+    }
+
+    test("tailstrictNonTailPosition") {
+      // tailstrict call in non-tail position (bound to a local variable).
+      // The call is NOT in tail position of an enclosing function, so it goes through
+      // visitApply* (not visitExprWithTailCallSupport). visitApply* only wraps the result in
+      // TailCall.resolve when the application itself is marked tailstrict, so the outer call
+      // site must carry `tailstrict` for this test to reach that code path; the resolve there
+      // must then unwind the whole TailCall chain produced by the callee.
+      eval(
+        """
+          |local sum(n, accum=0) =
+          |  if n <= 0 then accum
+          |  else sum(n - 1, accum + n) tailstrict;
+          |
+          |local result = sum(10000) tailstrict;
+          |result + 1
+          |""".stripMargin
+      ) ==> ujson.Num(50005001)
+    }
+
+    test("tailstrictBuiltinFoldl") {
+      // std.foldl invokes a user callback with TailstrictModeDisabled.
+      // The callback itself uses tailstrict recursion internally — verifies that
+      // Val.Func.apply* resolves TailCall before returning to the builtin.
+      eval(
+        """
+          |local power(base, exp, accum=1) =
+          |  if exp <= 0 then accum
+          |  else power(base, exp - 1, accum * base) tailstrict;
+          |
+          |std.foldl(function(acc, x) acc + power(2, x), [0, 1, 2, 3, 4], 0)
+          |""".stripMargin
+      ) ==> ujson.Num(31)
+    }
+
+    test("tailstrictReturnsContainer") {
+      // Tail-recursive function that returns an object/array at the base case.
+      // Verifies that TailCall.resolve correctly resolves to a container value
+      // that the Materializer can then process without issues.
+      eval(
+        """
+          |local buildList(n, accum=[]) =
+          |  if n <= 0 then accum
+          |  else buildList(n - 1, accum + [n]) tailstrict;
+          |
+          |local buildObj(n, accum={}) =
+          |  if n <= 0 then accum
+          |  else buildObj(n - 1, accum { ["k" + n]: n }) tailstrict;
+          |
+          |{
+          |  list: buildList(5),
+          |  obj: buildObj(3),
+          |}
+          |""".stripMargin
+      ) ==> ujson.Obj(
+        "list" -> ujson.Arr(
+          ujson.Num(5),
+          ujson.Num(4),
+          ujson.Num(3),
+          ujson.Num(2),
+          ujson.Num(1)
+        ),
+        "obj" -> ujson.Obj(
+          "k3" -> ujson.Num(3),
+          "k2" -> ujson.Num(2),
+          "k1" -> ujson.Num(1)
+        )
+      )
+    }
+  }
+}